Linear Regression

Using Python and Scipy


This page can be downloaded as an interactive Jupyter notebook.


In this notebook, we use the scipy module to perform a linear regression in Python. Given a set of 2D points $(x,y)$, we want to estimate the parameters $a, b$ of the linear equation $y = a\cdot x + b$ such that the mean squared error between the predicted and the observed $y$ values over all samples is minimal.
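
For a single input variable, the least-squares line has a closed-form solution: the slope is the covariance of $x$ and $y$ divided by the variance of $x$, and the intercept follows from the fact that the fitted line passes through the point of means. The following is a minimal sketch of that formula using plain NumPy. The helper name `fit_line` is hypothetical and only serves as an illustration; it is not part of scipy.

```python
import numpy as np

def fit_line(x, y):
    """Closed-form least-squares fit of y = a*x + b (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    # Slope: covariance of x and y divided by the variance of x
    a = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # Intercept: the fitted line passes through the point of means
    b = y_mean - a * x_mean
    return a, b

# Noise-free points on the line y = 2x + 1 are recovered up to rounding
a_demo, b_demo = fit_line([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0])
print('slope:', a_demo, 'intercept:', b_demo)
```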


Preparation

In order to implement the regression, we first import the required Python modules:

import numpy as np                          # Used for numerical computations
from scipy.stats import linregress          # Implementation of the regression 
import matplotlib.pyplot as plt             # Plotting library  

# This is to set the size of the plots in the notebook
plt.rcParams['figure.figsize'] = [6, 6]    

Creating a Toy Dataset

Next, we will create a toy dataset. It will contain noisy samples drawn from a known line.

ar, br = 0.873, 1.243                          # Ground truth parameters

np.random.seed(0)
x = np.random.random(200)*10                 # Drawing 200 points in the range [0, 10)
y = ar*x + br + np.random.randn(200)           # Computing the y coordinates with additive white Gaussian noise

plt.scatter(x, y, c='black', marker='o', label='Data samples')
plt.legend(); plt.show()


Performing the Regression

The linear regression using scipy can be done in one line. The function will return:

  • a: slope
  • b: intercept
  • r: correlation coefficient
  • p: p-value for a hypothesis test whose null hypothesis is that the slope is zero
  • s: standard error of the estimated gradient
a,b,r,p,s = linregress(x,y)
print('slope:', a)
print('intercept:', b)
print('correlation coefficient:', r)
print('p-value:', p)
print('standard error of the estimated gradient:', s)
slope: 0.8291955978325258
intercept: 1.3499418955427611
correlation coefficient: 0.9265832004752252
p-value: 4.938058020631706e-86
standard error of the estimated gradient: 0.023918369974605936
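
The returned object can also be read by field name instead of being unpacked, and the standard error of the slope can be turned into a confidence interval. The sketch below regenerates the same toy data so that it runs on its own. The 95 % level and the use of the Student t distribution with `n - 2` degrees of freedom are assumptions of this example; they presuppose independent Gaussian noise with constant variance, which holds for the toy dataset by construction.

```python
import numpy as np
from scipy.stats import linregress, t

# Same toy data as above
np.random.seed(0)
x = np.random.random(200) * 10
y = 0.873 * x + 1.243 + np.random.randn(200)

result = linregress(x, y)

# Access the results by name instead of by position
print('slope:', result.slope)
print('intercept:', result.intercept)
print('p-value:', result.pvalue)

# Squared correlation coefficient: share of the variance of y explained by the line
r_squared = result.rvalue ** 2
print('r squared:', r_squared)

# Two-sided 95% confidence interval for the slope (assumed confidence level)
dof = len(x) - 2                      # degrees of freedom for a line fit
t_crit = t.ppf(0.975, dof)            # critical value of the t distribution
slope_lo = result.slope - t_crit * result.stderr
slope_hi = result.slope + t_crit * result.stderr
print('95% interval for the slope:', slope_lo, slope_hi)
```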

The correlation coefficient is close to 1.0, which indicates a strong linear relationship between $x$ and $y$. The very small p-value means that a slope of zero is very unlikely given the data. For visualization, we can draw the real and estimated lines:

ar, br = 0.873, 1.243                          # Ground truth parameters

plt.scatter(x, y, c='black', marker='o', label='Data samples')

x_line = np.array([-1.0, 11.0])
real_y = ar*x_line + br
esti_y = a*x_line + b

plt.plot(x_line, real_y,c='red',lw=3, label='Real line')
plt.plot(x_line, esti_y,c='green',lw=3, label='Estimated line')

plt.legend(); plt.show()
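
The two lines can also be compared numerically. Because the regression minimizes the mean squared error on the samples, the estimated line cannot have a larger error on this dataset than the ground truth line, even though the ground truth line generated the data. The sketch below checks this; the helper `mse` is a hypothetical name introduced for this example, and the data is regenerated so that the block is self-contained.

```python
import numpy as np
from scipy.stats import linregress

ar, br = 0.873, 1.243                 # Ground truth parameters

# Same toy data as above
np.random.seed(0)
x = np.random.random(200) * 10
y = ar * x + br + np.random.randn(200)

result = linregress(x, y)

def mse(slope, intercept):
    """Mean squared error of the line y = slope*x + intercept on the samples."""
    return np.mean((y - (slope * x + intercept)) ** 2)

mse_est = mse(result.slope, result.intercept)
mse_real = mse(ar, br)
print('MSE of the estimated line:', mse_est)
print('MSE of the real line:', mse_real)
```

The difference between the two errors is small here, which reflects that the estimated line adapts slightly to the particular noise realization of these 200 samples.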


Author: Dennis Wittich
Last modified: 06.05.2019