Gist sergeyprokudin/0979e6d1c0cf7e017d6cc62d698f9ca7, last active January 7, 2019 22:18
The code for simple linear regression (exact solution)
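Here is a small worked instance of the normal equations X^T*X*w = X^T*y, which the docstring of `linear_regression_1d` derives. It uses three hypothetical points (x, y) = (0, 1), (1, 3), (2, 5) lying exactly on y = 1 + 2x. As in the demos, a column of ones is prepended so that w_1 is the intercept:

```latex
X = \begin{pmatrix} 1 & 0 \\ 1 & 1 \\ 1 & 2 \end{pmatrix}, \quad
y = \begin{pmatrix} 1 \\ 3 \\ 5 \end{pmatrix}, \quad
X^T X = \begin{pmatrix} 3 & 3 \\ 3 & 5 \end{pmatrix}, \quad
X^T y = \begin{pmatrix} 9 \\ 13 \end{pmatrix}

\begin{cases} 3 w_1 + 3 w_2 = 9 \\ 3 w_1 + 5 w_2 = 13 \end{cases}
\;\Rightarrow\; 2 w_2 = 4
\;\Rightarrow\; w = (w_1, w_2) = (1, 2)
```

The points are noise-free, so L(w) = 0 at this w.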
import numpy as np


def linear_regression_1d(x, y):
    """Solves a linear regression problem with a one-dimensional (scalar) target:
    finds w = (w_1, ..., w_n) such that

        w* = argmin_{w} L(w),
        L(w) = sum_{i=1}^{m} (<w, x_i> - y_i)**2,
        <w, x_i> = w_1*x_{i1} + w_2*x_{i2} + ... + w_n*x_{in},

    where x_i is the i-th row of the m x n matrix X (argument `x`) and y_i is
    the i-th entry of y.

    Solutions:

    Algo #1: exact analytical solution

    L(w) is a convex function (can you prove it?*), so to find its global minimum
    it is enough to find a stationary point:

        dL/dw_j = 0, j = 1, ..., n
        dL/dw_j = sum_{i=1}^{m} 2*x_{ij}*(<w, x_i> - y_i) = 0  <=>
        sum_{i=1}^{m} x_{ij}*<w, x_i> = sum_{i=1}^{m} x_{ij}*y_i, j = 1, ..., n  <=>
        X^T*X*w = X^T*y

    i.e. we have to solve the system of linear equations

        A*w = b, where A = X^T*X, b = X^T*y.

    If A^(-1) exists, the solution is

        w = A^(-1)*b, i.e. w = (X^T*X)^(-1)*X^T*y.

    If A^(-1) does not exist (e.g. some columns of X are linearly dependent), we
    can use the Moore-Penrose pseudoinverse A^+ instead: w = (X^T*X)^+*X^T*y,
    which equals X^+*y, the minimum-norm least-squares solution. This is what the
    code below computes.

    Complexity of the solution:
        Space: O(m*n)
        Time:
            (n x m) x (m x n) matrix product -> O(m*n^2)
            n x n matrix (pseudo)inverse     -> O(n^3)
            if m > n => total time: O(m*n^2)

    Supplementary
    -------------
    Convex function definition: f(x) is convex if the line segment between
    any two points on the graph of the function lies above or on the graph, i.e.

        f(t*x_1 + (1-t)*x_2) <= t*f(x_1) + (1-t)*f(x_2)

    for any t in (0, 1) and any x_1, x_2.
    """
    # w = (X^T X)^+ X^T y; the pseudoinverse also covers a singular X^T X
    w = np.linalg.pinv(x.T @ x) @ x.T @ y
    return w

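# DEMO
import matplotlib.pyplot as plt
# NOTE: load_boston was deprecated in scikit-learn 1.0 and removed in 1.2,
# so the Boston housing demo below requires scikit-learn < 1.2.
from sklearn.datasets import load_boston

# IPython/Jupyter magic, not valid in a plain Python script; uncomment in a notebook:
# %matplotlib inline

# Artificial data: we set the regression task y = 1 + 2*x + N(0, 1) noise
x = np.arange(0, 10, 0.1)
x = np.vstack([np.ones(x.shape), x]).T  # column of ones for the intercept w_1
y = x @ np.array([[1], [2]]) + np.random.normal(size=[x.shape[0], 1], scale=1.0)
w = linear_regression_1d(x, y)

# Fitted line evaluated on a grid
xr = np.arange(0, 10, 0.1)
xr = np.vstack([np.ones(xr.shape), xr]).T
yr = xr @ w
plt.title("Artificial data")
plt.scatter(xr[:, 1], yr, s=4)  # fitted line (small dots)
plt.scatter(x[:, 1], y)  # noisy observations
plt.show()
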
# Regressing the target on a single feature of the Boston housing data
# (column 12 is LSTAT, the percentage of lower-status population)
ds = load_boston()
x = ds['data'][:, 12]
x = np.vstack([np.ones(x.shape), x]).T
y = ds['target']
w = linear_regression_1d(x, y)

# Evaluate the fitted line over the observed range of the feature
xr = np.linspace(x[:, 1].min(), x[:, 1].max(), 100)
xr = np.vstack([np.ones(xr.shape), xr]).T
yr = xr @ w
plt.title("Boston housing data")
plt.scatter(xr[:, 1], yr, s=4)  # fitted line (small dots)
plt.scatter(x[:, 1], y)  # observations
plt.show()
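The closed-form solution can be sanity-checked against two other ways of minimizing the same L(w). This is a minimal sketch that assumes only NumPy and uses synthetic data shaped like the artificial demo. It re-declares the gist's `linear_regression_1d` so that it runs on its own. `loss_gradient` is a hypothetical helper for dL/dw = 2*X^T*(X*w - y). The sketch compares the pseudoinverse result with `np.linalg.solve` on A*w = b and with `np.linalg.lstsq`, then checks that the gradient vanishes at the solution:

```python
import numpy as np


def linear_regression_1d(x, y):
    # Copy of the gist's closed-form solution: w = pinv(X^T X) X^T y
    return np.linalg.pinv(x.T @ x) @ x.T @ y


def loss_gradient(x, y, w):
    # Hypothetical helper: gradient of L(w) = sum_i (<w, x_i> - y_i)**2, i.e. 2 * X^T (X w - y)
    return 2.0 * x.T @ (x @ w - y)


rng = np.random.default_rng(0)  # fixed seed so the check is reproducible

# Same kind of data as the artificial demo: y = 1 + 2*t + N(0, 1) noise
t = np.arange(0, 10, 0.1)
x = np.vstack([np.ones(t.shape), t]).T
y = x @ np.array([1.0, 2.0]) + rng.normal(scale=1.0, size=t.shape[0])

w_pinv = linear_regression_1d(x, y)

# A = X^T X is invertible here, so the normal equations can be solved directly
w_solve = np.linalg.solve(x.T @ x, x.T @ y)

# NumPy's least-squares solver minimizes the same L(w) without forming X^T X
w_lstsq, *_ = np.linalg.lstsq(x, y, rcond=None)

print("pinv: ", w_pinv)
print("solve:", w_solve)
print("lstsq:", w_lstsq)
print("agree:", np.allclose(w_pinv, w_solve), np.allclose(w_pinv, w_lstsq))

# At the minimizer the gradient vanishes up to floating-point error
print("|grad L(w)| =", np.linalg.norm(loss_gradient(x, y, w_pinv)))
```

For ill-conditioned data, `np.linalg.lstsq` is usually the safer route. Forming X^T*X squares the condition number of X.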
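A design matrix with a duplicated column exercises the docstring's fallback for a singular A. This is a minimal sketch assuming only NumPy. `linear_regression_pinv` is a hypothetical variant of the gist's one-liner. It passes an explicit `rcond` cutoff to `np.linalg.pinv`, and the value 1e-10 is an arbitrary choice. The cutoff is there because A = X^T*X is formed in floating point, so its zero singular value is only zero up to rounding and may not fall below the default cutoff. The sketch checks that the result matches `np.linalg.lstsq` and is the minimum-norm solution:

```python
import numpy as np


def linear_regression_pinv(x, y, rcond=1e-10):
    # Hypothetical variant of the gist's formula with an explicit pseudoinverse cutoff:
    # singular values of X^T X below rcond * (largest singular value) are treated as zero.
    return np.linalg.pinv(x.T @ x, rcond=rcond) @ x.T @ y


rng = np.random.default_rng(1)  # fixed seed for reproducibility

# The third column duplicates the second, so rank(X) = 2 < n = 3 and
# A = X^T X is singular: (X^T X)^(-1) does not exist.
t = rng.uniform(0, 10, size=50)
x = np.column_stack([np.ones_like(t), t, t])
y = 1.0 + 2.0 * t + rng.normal(scale=0.5, size=t.shape[0])
print("rank(X) =", np.linalg.matrix_rank(x))

w = linear_regression_pinv(x, y)

# np.linalg.lstsq also returns the minimum-norm minimizer of L(w) in this case
w_lstsq, *_ = np.linalg.lstsq(x, y, rcond=None)
print("pinv: ", w)
print("lstsq:", w_lstsq)
print("agree:", np.allclose(w, w_lstsq))

# Every w with the same w_2 + w_3 gives identical predictions; the minimum-norm
# choice splits the slope evenly between the two identical columns.
print("w_2 =", w[1], " w_3 =", w[2], " w_2 + w_3 =", w[1] + w[2])
```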