Skip to content

Instantly share code, notes, and snippets.

@mikmart
Created January 19, 2016 17:08
Show Gist options
  • Save mikmart/952b7ab55111e25d6d3e to your computer and use it in GitHub Desktop.
Save mikmart/952b7ab55111e25d6d3e to your computer and use it in GitHub Desktop.
Simple OLS with Python
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
'''
Simple OLS with Python
Fitting a simple OLS regression model using linear algebra with Python,
following the first Matlab programming exercise for Machine Learning.
Author:
Mikko Marttila
Date:
19 Jan 2016
Required modules:
numpy, scipy, matplotlib (for visualization)
Module installation (on Ubuntu):
sudo apt-get install -y python3-numpy python3-scipy python3-matplotlib
'''
import numpy.matlib as npm
import scipy.linalg as spl
import scipy.io
import matplotlib.pyplot as plt
def head(matrix, n=10):
    """Return the first *n* rows of *matrix*.

    Only the first axis is sliced, so 2-D matrices/arrays as well as 1-D
    arrays (or any sliceable sequence) are accepted. For NumPy arrays and
    matrices basic slicing returns a view that shares memory with *matrix*,
    not a copy.

    Args:
        matrix: Array-like supporting slicing along its first axis.
        n: Number of leading rows to return (default 10). If *matrix* has
            fewer than *n* rows, all of them are returned.

    Returns:
        The leading rows of *matrix*, same type as the input.
    """
    return matrix[:n]
# Script body: load the data, fit y = w0 + w1 * x by OLS, report the
# coefficients and R², then plot the data with the fitted line.

# Load data from matlab data format
data = scipy.io.loadmat("lr_2D_data.mat")
y = npm.matrix(data["prices"])
x = npm.matrix(data["sizes"])
print("\nHead of the data:")
print(head(npm.hstack([y, x])))

# Assemble the model matrix for simple linear regression: y = w0 + w1 * x
ones = npm.ones(x.shape)
X = npm.hstack([ones, x])
print("\nHead of the model matrix:")
print(head(X))

# OLS estimator for regression coefficients, w = (X'X)^+ X'y = X^+ y.
# lstsq computes that same minimum-norm least-squares solution directly
# from X. This avoids forming X'X, which squares the condition number of X
# and loses precision when the columns are nearly collinear.
w = npm.matrix(spl.lstsq(X, y)[0])
print("\nRegression coefficients:")
print(w)

# Fitted values for prices at observed sizes. Columns are [fit, x], in the
# same order as the [y, x] data head above, so the two printouts can be
# compared row by row. The constant intercept column is omitted.
fit = X * w
print("\nSome fitted values:")
print(head(npm.hstack([fit, x])))

# Coefficient of determination. .sum() reduces to a NumPy scalar. The
# builtin sum() would loop over rows in Python and leave a 1x1 matrix, and
# calling float() on a non-0-d array is deprecated since NumPy 1.25.
TSS = npm.square(y - npm.mean(y)).sum()  # Total sum of squares
RSS = npm.square(y - fit).sum()  # Residual sum of squares
R2 = 1 - RSS / TSS
print("\nR² fit statistic: ", end="")
print(round(float(R2), 3))

# Predicted values on an evenly spaced grid spanning the observed sizes,
# so the fitted line covers the data whatever its range or units.
x_new = npm.matrix(npm.linspace(x.min(), x.max(), 100)).T  # Sizes to predict at
X_new = npm.hstack([npm.ones(x_new.shape), x_new])  # Prediction model matrix
pred = X_new * w

# Visualization
plt.plot(x, y, 'bo')  # Scatterplot of data
plt.plot(x_new, pred, 'r')  # Predicted line from OLS
plt.ylabel("Price")
plt.xlabel("Size")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment