@sergeyprokudin
Last active January 7, 2019 22:18
The code for simple linear regression (exact solution)
import numpy as np
def linear_regression_1d(x, y):
    """Solves the 1-dimensional regression problem: provides w = (w_1, ..., w_n) such that

        w* = argmin_{w} L(w)
        L(w) = sum_{i=1}^{m} (<w, x_i> - y_i)**2
        <w, x_i> = w_1*x_{i1} + w_2*x_{i2} + ... + w_n*x_{in}

    Here x is the (m x n) matrix X whose rows are the samples x_i,
    and y holds the m target values y_i.

    Solutions:

    Algo #1: exact analytical solution

    L(w) is a convex function (can you prove it?*) => to find its global minimum
    we need to find its stationary point:

        dL/dw_j = 0, j=1,..,n

        dL/dw_j = sum_{i=1}^{m} 2*x_{ij}*(<w, x_i> - y_i) = 0  <=>
        sum_{i=1}^{m} x_{ij}*<w, x_i> = sum_{i=1}^{m} x_{ij}*y_i, j=1,..,n

    In matrix form:

        X^T*X*w = X^T*y

    <=> solving a system of linear equations:

        A*w = b, with A = X^T*X and b = X^T*y

    Solution to the system: if A^(-1) exists,

        w = A^(-1)*b, i.e. w = (X^T*X)^(-1)*X^T*y

    If A^(-1) does not exist, we can use the pseudoinverse of A instead.

    Complexity of the solution:

    Space: O(m*n); time:
        X^T*X, i.e. (n x m) x (m x n) -> O(m*n^2)
        n x n matrix inverse -> O(n^3)
        if m > n => time: O(m*n^2)

    Supplementary
    -------------
    * Convex function definition: f(x) is convex if the line segment between
    any two points on the graph of the function lies above or on the graph, i.e.

        f(t*x_1 + (1-t)*x_2) <= t*f(x_1) + (1-t)*f(x_2)

    for any t in (0, 1) and any x_1, x_2.
    """
    # pinv equals the ordinary inverse when X^T*X is invertible
    # and falls back to the pseudoinverse when it is not
    w = np.linalg.pinv(x.T @ x) @ x.T @ y
    return w
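
# Sanity check for the closed-form solution derived in the docstring above.
# This is a minimal sketch added for illustration, not part of the original gist:
# check_normal_equations is a hypothetical helper, and the sizes m=50, n=3 and the
# seed are arbitrary assumptions. It compares the pinv-based formula with
# np.linalg.lstsq and evaluates the gradient 2*X^T*(X*w - y), which should be
# (numerically) zero at the stationary point.
import numpy as np

def check_normal_equations(m=50, n=3, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(m, n))   # random design matrix, m samples, n features
    y = rng.normal(size=(m, 1))   # random targets
    # w = pinv(X^T*X) * X^T * y, the same expression the function above uses
    w_closed = np.linalg.pinv(X.T @ X) @ X.T @ y
    # Reference least-squares solution from NumPy's solver
    w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
    # Gradient of L(w) at the closed-form solution
    grad = 2 * X.T @ (X @ w_closed - y)
    return w_closed, w_lstsq, grad

w_closed, w_lstsq, grad = check_normal_equations()
print("matches lstsq:", np.allclose(w_closed, w_lstsq))
print("gradient is ~0:", np.allclose(grad, 0, atol=1e-8))
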
# DEMO
# Note: load_boston was removed in scikit-learn 1.2, so this import needs an older release
from sklearn.datasets import load_boston
import matplotlib.pyplot as plt
%matplotlib inline
# Artificial data: we set the regression task
x = np.arange(0, 10, 0.1)
x = np.vstack([np.ones(x.shape), x]).T
y = x @ np.array([[1], [2]]) + np.random.normal(size=[x.shape[0], 1], scale=1.0)
w = linear_regression_1d(x, y)
xr = np.arange(0, 10, 0.1)
xr = np.vstack([np.ones(xr.shape), xr]).T
yr = xr@w
plt.title("Artificial data")
plt.scatter(xr[:, 1], yr, s=4)
plt.scatter(x[:, 1], y)
plt.show()
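
# The plot above only shows the fit visually. As a complementary sketch (an
# addition, not part of the original gist), the estimated weights can be compared
# with the true ones used to generate the data, (1, 2). The names x_demo, y_demo,
# w_true and w_hat are illustrative, and the seed is an arbitrary choice made
# only so the run is reproducible.
import numpy as np

rng = np.random.default_rng(0)
x_demo = np.arange(0, 10, 0.1)
x_demo = np.vstack([np.ones(x_demo.shape), x_demo]).T   # column of ones -> intercept
w_true = np.array([[1.0], [2.0]])                        # intercept 1, slope 2
y_demo = x_demo @ w_true + rng.normal(size=[x_demo.shape[0], 1], scale=1.0)
# Same closed-form expression as in linear_regression_1d
w_hat = np.linalg.pinv(x_demo.T @ x_demo) @ x_demo.T @ y_demo
print("true w:     ", w_true.ravel())
print("estimated w:", w_hat.ravel())
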
# Regressing the target on one feature of the Boston housing data (column 12, LSTAT)
ds = load_boston()
x = ds['data'][:, 12]
x = np.vstack([np.ones(x.shape), x]).T
y = ds['target']
w = linear_regression_1d(x, y)
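# Evaluate the fitted line over the observed feature range rather than a fixed 0..10 interval
xr = np.linspace(x[:, 1].min(), x[:, 1].max(), 100)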
xr = np.vstack([np.ones(xr.shape), xr]).T
yr = xr@w
plt.title("Boston housing data")
plt.scatter(xr[:, 1], yr, s=4)
plt.scatter(x[:, 1], y)
plt.show()
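
# Supplementary sketch for the "can you prove it?" question in the docstring.
# This is an illustration added here, not part of the original gist, and a
# numerical check is not a proof. It tests the convexity inequality
#   L(t*w1 + (1-t)*w2) <= t*L(w1) + (1-t)*L(w2)
# on random points, and checks that the Hessian of L, 2*X^T*X, has no negative
# eigenvalues. The helper squared_loss, the data sizes and the seed are
# assumptions made for the example.
import numpy as np

def squared_loss(w, X, y):
    # L(w) = sum_i (<w, x_i> - y_i)**2
    r = X @ w - y
    return float(np.sum(r ** 2))

rng = np.random.default_rng(1)
X_c = rng.normal(size=(30, 2))
y_c = rng.normal(size=(30,))

violations = 0
for _ in range(1000):
    w1, w2 = rng.normal(size=2), rng.normal(size=2)
    t = rng.uniform(0, 1)
    lhs = squared_loss(t * w1 + (1 - t) * w2, X_c, y_c)
    rhs = t * squared_loss(w1, X_c, y_c) + (1 - t) * squared_loss(w2, X_c, y_c)
    if lhs > rhs + 1e-9:   # small tolerance for floating-point rounding
        violations += 1

min_eig = np.linalg.eigvalsh(2 * X_c.T @ X_c).min()
print("violations of the convexity inequality:", violations)
print("smallest Hessian eigenvalue is non-negative:", min_eig >= -1e-9)
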