Last active
January 8, 2019 02:36
-
-
Save RyotaBannai/786e86e9ddd40f27c8b49cfe2806be84 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

plt.style.use('ggplot')

# Two side-by-side panels: OLS residuals on the left (ax1), TLS residuals on
# the right (ax2).
# The tight-layout padding is handed to the figure itself, so the layout pass
# runs at draw time. That way it accounts for the titles and axis labels that
# are added further down. Calling plt.tight_layout() here would only lay out
# the still-empty axes.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 6),
                               tight_layout=dict(w_pad=1.5))

# plt.xlim / plt.ylim only act on the *current* axes (ax2). Set the view on
# both panels explicitly so they show the same region and are comparable.
for ax_ in (ax1, ax2):
    ax_.set_xlim((0, 10))
    ax_.set_ylim((0, 7))
#red line: the reference line y = slope * x + intercept, defined by the two
#points it passes through (one (x, y) pair per row)
points = np.array([[1, 1],
                   [9, 5]])
slope = np.diff(points[:, 1])[0] / np.diff(points[:, 0])[0]
intercept = points[0, 1] - points[0, 0] * slope
#data points: one observation (x, y) per row
xy = np.array([
    [2, 1],
    [2.5, 3.5],
    [4, 4],
    [5, 1.5],
    [6, 3],
    [7, 6],
    [7.5, 3.5],
])
#calculate projection points
def y_fun(x, s=slope, i=intercept):
    """Evaluate the red reference line ``y = s*x + i`` at *x*.

    The defaults capture the module-level ``slope`` and ``intercept`` at the
    moment this function is defined, the same early binding the original
    lambda used. Later reassignment of those globals therefore does not
    affect calls that rely on the defaults.

    *x* may be a scalar or a NumPy array; the arithmetic is element-wise.
    """
    return s * x + i
def y_fun_tls(x, y, slope, intercept):
    """Return the orthogonal projection of point (x, y) onto y = slope*x + intercept.

    The result is the foot of the perpendicular dropped from the point onto
    the line, i.e. the end point of the point's total-least-squares
    (orthogonal) residual.

    Parameters
    ----------
    x, y : float
        Coordinates of the point to project.
    slope, intercept : float
        Parameters of the (non-vertical) line.

    Returns
    -------
    numpy.ndarray
        Array ``[x_p, y_p]`` holding the projected coordinates.
    """
    # Closed form: project the offset (x, y - intercept) onto the line's
    # direction vector (1, slope). The previous approach intersected the line
    # with a perpendicular of slope -1/slope. That divided by zero for
    # horizontal lines; this formula has no such singularity.
    x_p = (x + slope * (y - intercept)) / (1.0 + slope * slope)
    y_p = slope * x_p + intercept
    return np.array([x_p, y_p], dtype=float)
#preprocess data for plot: sample the red line over the data's x-range
#(padded by 1 on each side) and project every observation onto it
ln = np.linspace(start=xy[:, 0].min() - 1, stop=xy[:, 0].max() + 1, num=10)
line = y_fun(ln)
# OLS projection: keep each x and move vertically onto the red line
y_ols_proj = list(map(y_fun, xy[:, 0]))
# TLS projection: perpendicular foot on the red line, one (x, y) row per point
xy_tls_proj = np.array([y_fun_tls(px, py, slope, intercept) for px, py in xy])
#plot: the red reference line and the blue observations go on both panels
for ax_ in [ax1, ax2]:
    ax_.plot(ln, line, alpha=.6, c='r')
    ax_.scatter(*xy.T, c='b', alpha=.5, s=70)
#projection points (green): vertical drop for OLS (left), perpendicular foot
#for TLS (right)
proj_prop = {'c': 'g', 's': 50, 'alpha': .5}
ax1.scatter(xy[:, 0], y_ols_proj, **proj_prop)
ax2.scatter(*xy_tls_proj.T, **proj_prop)
#residual grey lines: one segment from each observation to its projection
line_prop = {'c': 'gray', 'alpha': .6, 'linewidth': 3.}
# OLS residuals are vertical: both ends share the same x
for x, y_start, y_end in zip(xy[:, 0], xy[:, 1], y_ols_proj):
    ax1.plot([x, x], [y_start, y_end], **line_prop)
# TLS residuals run from each (x, y) row to its projected row
for (x_start, y_start), (x_end, y_end) in zip(xy, xy_tls_proj):
    ax2.plot([x_start, x_end], [y_start, y_end], **line_prop)
# Axis labels, panel titles and equal scaling for both panels.
for ax_, labelx, labely, title in zip((ax1, ax2), ('X', 'X'), ('Y', ''), ('OLS', 'TLS')):
    ax_.set_xlabel(labelx)
    ax_.set_ylabel(labely)
    ax_.set_title(title)
    # Equal scaling is what makes the TLS residuals look perpendicular to
    # the red line. plt.axis('equal') acted only on the current axes (ax2)
    # and re-enabled autoscaling, which discarded any limits set earlier.
    # Setting the aspect per axes keeps existing limits and treats both
    # panels alike.
    ax_.set_aspect('equal')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment