import numpy as np


def linear_regression_1d(x, y):
    """Solves the (single-output) linear least-squares regression problem.

    x is an (m, n) design matrix whose rows are the samples x_1, ..., x_m
    (add a column of ones to fit an intercept); y holds the m targets.
    Returns w = (w_1, ..., w_n) such that:

        w* = argmin_w L(w)
        L(w) = sum_{i=1}^{m} (<w, x_i> - y_i)^2
        <w, x_i> = w_1*x_{i1} + w_2*x_{i2} + ... + w_n*x_{in}

    Solutions
    ---------
    Algo #1: exact analytical solution

    L(w) is a convex function (can you prove it?*), so to find its global
    minimum we need to find its stationary point:

        dL/dw_j = 0,  j = 1, ..., n
        dL/dw_j = sum_{i=1}^{m} 2*x_{ij}*(<w, x_i> - y_i) = 0
        <=> sum_{i=1}^{m} x_{ij}*<w, x_i> = sum_{i=1}^{m} x_{ij}*y_i,  j = 1, ..., n
        <=> X^T X w = X^T y

    This is a system of linear equations A w = b with A = X^T X, b = X^T y.
    Solution to the system:
      - if A^(-1) exists, w = A^(-1) b, i.e. w = (X^T X)^(-1) X^T y;
      - if A^(-1) does not exist, we can use the pseudoinverse of A instead.

    Complexity of the solution:
      Space: O(m*n) to store X.
      Time: (n x m) @ (m x n) product -> O(m*n^2);
            inverting the n x n matrix -> O(n^3);
            total O(m*n^2 + n^3), i.e. O(m*n^2) when m > n.

    Supplementary
    -------------
    Convex function definition: f(x) is convex if the line segment between any
    two points on the graph of the function lies on or above the graph, i.e.

        f(t*x_1 + (1-t)*x_2) <= t*f(x_1) + (1-t)*f(x_2)  for any t in (0, 1), x_1, x_2
    """
    w = np.linalg.pinv(x.T @ x) @ x.T @ y
    return w


# DEMO
import matplotlib.pyplot as plt
# load_boston was removed in scikit-learn 1.2; this demo needs scikit-learn < 1.2.
from sklearn.datasets import load_boston

# Jupyter/IPython magic: remove this line when running as a plain Python script
%matplotlib inline

# Artificial data: we set the regression task y = 1 + 2*x + Gaussian noise
x = np.arange(0, 10, 0.1)
x = np.vstack([np.ones(x.shape), x]).T  # prepend a column of ones for the intercept
y = x @ np.asarray([[1], [2]]) + np.random.normal(size=[x.shape[0], 1], scale=1.0)
w = linear_regression_1d(x, y)

xr = np.arange(0, 10, 0.1)
xr = np.vstack([np.ones(xr.shape), xr]).T
yr = xr @ w

plt.title("Artificial data")
plt.scatter(xr[:, 1], yr, s=4)  # fitted line
plt.scatter(x[:, 1], y)  # noisy observations
plt.show()

# Regressing the target on 1 feature of the Boston housing data
ds = load_boston()
x = ds['data'][:, 12]  # column 12 is LSTAT (% lower status of the population)
x = np.vstack([np.ones(x.shape), x]).T
y = ds['target']
w = linear_regression_1d(x, y)

# Evaluate the fit over the observed feature range (LSTAT goes well beyond 10)
xr = np.linspace(x[:, 1].min(), x[:, 1].max(), 100)
xr = np.vstack([np.ones(xr.shape), xr]).T
yr = xr @ w

plt.title("Boston housing data")
plt.scatter(xr[:, 1], yr, s=4)  # fitted line
plt.scatter(x[:, 1], y)  # observations
plt.show()
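The derivation in the docstring says the minimizer is a stationary point, i.e. the gradient 2 X^T (X w - y) vanishes at w = pinv(X^T X) X^T y. The sketch below checks that numerically on synthetic data and compares the result with `np.linalg.lstsq`, NumPy's SVD-based least-squares solver, which never forms X^T X (forming it squares the condition number of X). The helpers `solve_normal_equations` and `loss_gradient`, the seed, and `w_true` are hypothetical names and values chosen for illustration only.

```python
import numpy as np


def solve_normal_equations(X, y):
    # Same formula as linear_regression_1d: w = pinv(X^T X) X^T y
    return np.linalg.pinv(X.T @ X) @ X.T @ y


def loss_gradient(X, y, w):
    # Gradient of L(w) = sum_i (<w, x_i> - y_i)^2, i.e. dL/dw = 2 X^T (X w - y)
    return 2.0 * X.T @ (X @ w - y)


rng = np.random.default_rng(0)
m, n = 50, 3
# Design matrix: a column of ones (intercept) plus n - 1 random features
X = np.hstack([np.ones((m, 1)), rng.normal(size=(m, n - 1))])
w_true = np.array([1.0, 2.0, -0.5])  # assumed "true" weights for the synthetic data
y = X @ w_true + rng.normal(scale=0.1, size=m)

w_ne = solve_normal_equations(X, y)
# SVD-based least squares; rcond=None uses NumPy's machine-precision cutoff
w_ls, *_ = np.linalg.lstsq(X, y, rcond=None)

print("gradient at w_ne:", loss_gradient(X, y, w_ne))
print("normal equations vs lstsq agree:", np.allclose(w_ne, w_ls))
```

Both routes solve the same least-squares problem; for well-conditioned X they agree to round-off, while `lstsq` tends to be the safer choice when the columns of X are nearly collinear.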
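For the "can you prove it?" remark: L(w) = ||X w - y||^2 has the constant Hessian 2 X^T X, and v^T (2 X^T X) v = 2 ||X v||^2 >= 0 for every v, so the Hessian is positive semidefinite and L is convex. The sketch below illustrates this numerically on random data (a spot check, not a proof): it inspects the Hessian's eigenvalues and tests the inequality from the Supplementary definition at random points. The helper `loss`, the seed, and the tolerances are assumptions for illustration.

```python
import numpy as np


def loss(X, y, w):
    # L(w) = sum_i (<w, x_i> - y_i)^2 = ||X w - y||^2
    r = X @ w - y
    return float(r @ r)


rng = np.random.default_rng(1)
m, n = 30, 4
X = rng.normal(size=(m, n))
y = rng.normal(size=m)

# The Hessian of L is 2 X^T X; eigvalsh is used because the matrix is symmetric.
hessian = 2.0 * X.T @ X
eigvals = np.linalg.eigvalsh(hessian)
print("smallest Hessian eigenvalue:", eigvals.min())

# Spot-check f(t*x_1 + (1-t)*x_2) <= t*f(x_1) + (1-t)*f(x_2) at random points;
# a small absolute tolerance absorbs floating-point round-off.
violations = 0
for _ in range(1000):
    w1, w2 = rng.normal(size=n), rng.normal(size=n)
    t = rng.uniform(0.0, 1.0)
    lhs = loss(X, y, t * w1 + (1 - t) * w2)
    rhs = t * loss(X, y, w1) + (1 - t) * loss(X, y, w2)
    if lhs > rhs + 1e-9:
        violations += 1
print("convexity violations:", violations)
```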
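The docstring notes that when A = X^T X has no inverse, the pseudoinverse can be used instead. A minimal sketch, assuming a hypothetical rank-deficient design where the same feature appears twice: every w with w_0 = 1 and w_1 + w_2 = 2 then solves A w = b, and `np.linalg.pinv` returns the solution with the smallest Euclidean norm, here w = (1, 1, 1). The explicit cutoffs `rcond=1e-10` and `tol=1e-8` are assumed values, chosen so that a round-off-sized singular value of A is treated as zero.

```python
import numpy as np

# Hypothetical rank-deficient design: the feature x appears twice, so the
# columns of X are linearly dependent and A = X^T X is singular.
x = np.arange(10, dtype=float)
X = np.column_stack([np.ones_like(x), x, x])
y = 1.0 + 2.0 * x  # noise-free target, so an exact fit exists

A = X.T @ X
b = X.T @ y
print("rank of A:", np.linalg.matrix_rank(A, tol=1e-8), "of", A.shape[0])

# Pseudoinverse solution: singular values <= rcond * largest are treated as zero
w = np.linalg.pinv(A, rcond=1e-10) @ b

# pinv picks the minimum-norm solution, splitting the weight evenly
# between the two identical columns.
print("w:", np.round(w, 6))
print("solves A w = b:", np.allclose(A @ w, b))
```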