Last active: November 16, 2015 00:07
-
-
Save thekensta/0209362fe642a4201c0b to your computer and use it in GitHub Desktop.
Summary of least squares in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Quick reminder of least-squares calculations in Python
import numpy as np | |
def least_sq_numpy(x, y):
    """Calculate y = mx + c from x, y returning m, c using numpy.

    Fits a straight line by ordinary least squares with
    ``np.linalg.lstsq``.

    Parameters
    ----------
    x, y : array_like
        1-D sequences of equal length (plain lists are accepted).

    Returns
    -------
    numpy.ndarray
        Two-element array ``[m, c]`` (slope, intercept).
    """
    x = np.asarray(x)
    # Design matrix [x, 1]: the column of ones lets lstsq fit the intercept.
    A = np.vstack([x, np.ones(x.size)]).T
    # Pass rcond=None explicitly: it selects the machine-precision cutoff and
    # avoids the FutureWarning NumPy 1.14-1.26 emit when rcond is omitted.
    fit = np.linalg.lstsq(A, y, rcond=None)
    # lstsq returns (solution, residuals, rank, singular_values).
    return fit[0]
def least_sq_mat(x, y):
    """Calculate y = mx + c via the normal equations, returning (m, c).

    Solves ``(X^T X) b = X^T y`` where ``X`` has columns ``[x, 1]``.

    Parameters
    ----------
    x, y : array_like
        1-D sequences of equal length (plain lists are accepted).

    Returns
    -------
    tuple
        ``(m, c)`` -- slope and intercept.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``X^T X`` is singular (e.g. all x values are identical).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    X = np.column_stack((x, np.ones_like(x)))
    # The solver lives in numpy.linalg; a bare np.solve does not exist.
    b = np.linalg.solve(np.dot(X.T, X), np.dot(X.T, y))
    return b[0], b[1]
def least_sq(x, y):
    """Calculate y = mx + c from x and y, returning (m, c).

    Closed-form simple linear regression:
    ``m = sum((x - xbar) * (y - ybar)) / sum((x - xbar)**2)`` and
    ``c = ybar - m * xbar``.

    Parameters
    ----------
    x, y : array_like
        1-D sequences of equal length (plain lists are accepted).

    Returns
    -------
    tuple
        ``(m, c)`` -- slope and intercept.

    Raises
    ------
    ValueError
        If x is empty or all x values are identical (slope undefined).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean = x.mean()
    y_mean = y.mean()
    xm = x - x_mean
    ym = y - y_mean
    sxx = (xm * xm).sum()
    # A zero spread in x makes the slope 0/0; fail loudly instead of
    # returning nan/inf.
    if sxx == 0:
        raise ValueError("least_sq needs at least two distinct x values")
    m = (ym * xm).sum() / sxx
    # Intercept uses the slope just computed (previously an undefined name).
    c = y_mean - m * x_mean
    return m, c
def r2(y, yhat):
    """Return R-squared as explained over total sum of squares.

    Computes ``sum((yhat - ybar)**2) / sum((y - ybar)**2)``, where ``ybar``
    is the mean of the observed ``y``. For an ordinary least-squares fit
    that includes an intercept this equals the familiar
    ``1 - SS_res / SS_tot``.
    """
    y_bar = y.mean()
    explained = yhat - y_bar
    total = y - y_bar
    ss_explained = np.sum(explained ** 2)
    ss_total = np.sum(total ** 2)
    return ss_explained / ss_total
def stderr(y, yhat):
    """Return the residual standard error of a two-parameter line fit.

    Computes ``sqrt(sum((y - yhat)**2) / (N - 2))``; the ``N - 2`` degrees
    of freedom account for the two fitted parameters (m and c).

    Parameters
    ----------
    y : array_like
        Observed values.
    yhat : array_like
        Fitted values, same length as ``y``.

    Raises
    ------
    ValueError
        If there are 2 or fewer observations (no residual degrees of freedom).
    """
    y = np.asarray(y, dtype=float)
    yhat = np.asarray(yhat, dtype=float)
    n_obs = y.size
    # Explicit check rather than ``assert``: asserts are stripped under -O.
    if n_obs <= 2:
        raise ValueError(
            "stderr needs more than 2 observations, got %d" % n_obs)
    residuals = y - yhat
    return np.sqrt((residuals * residuals).sum() / (n_obs - 2))
# Demo: noisy samples of the line y = x, fitted with the helpers above.
# The noise comes from NumPy's global RNG without a seed, so printed numbers
# change from run to run; the values in the comments are from one example run.
x = np.linspace(0, 10, num=50)
y = x + np.random.normal(
    scale=5.0,
    size=50,
)

m, c = least_sq(x, y)
print("m", m, "c", c)
# m 1.56761305698 c -2.23354134041

yhat = m * x + c
print("r-sq", r2(y, yhat))
print("stderr", stderr(y, yhat))
# r-sq 0.406452211459
# stderr 5.69406972454

# Cross-check: NumPy's lstsq-based fit should give the same slope/intercept.
print(least_sq_numpy(x, y))
# [ 1.56761306 -2.23354134]
# If running in IPython with rpy2 installed, validate the fit against R's lm().
# The ``%`` magic syntax is only understood by IPython and makes this file a
# SyntaxError under plain ``python``, so the magics are invoked through the
# IPython API instead and skipped when no IPython shell is active.
try:
    _ipython = get_ipython()  # noqa: F821 -- builtin injected by IPython
except NameError:
    _ipython = None

if _ipython is not None:
    _ipython.run_line_magic("load_ext", "rpy2.ipython")
    # -i x,y pushes the Python arrays x and y into the R session.
    _ipython.run_line_magic(
        "R",
        "-i x,y X <- x; Y <- y; mdl <- lm(Y ~ X); print(summary(mdl))",
    )

# R output from the example run (agrees with the Python results above):
# Call:
# lm(formula = Y ~ X)
# Residuals:
#      Min       1Q   Median       3Q      Max
# -13.5837  -2.4414  -0.1266   2.6261  14.0993
# Coefficients:
#             Estimate Std. Error t value Pr(>|t|)
# (Intercept)  -2.2335     1.5867  -1.408    0.166
# X             1.5676     0.2734   5.733 6.39e-07 ***
# ---
# Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
# Residual standard error: 5.694 on 48 degrees of freedom
# Multiple R-squared:  0.4065, Adjusted R-squared:  0.3941
# F-statistic: 32.87 on 1 and 48 DF,  p-value: 6.391e-07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.