Created
February 16, 2012 21:51
-
-
Save ameerkat/1848143 to your computer and use it in GitHub Desktop.
Univariate Batch Gradient Descent
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Univariate Batch Gradient Descent - Linear Regression | |
| # Derived from Andrew Ng's Lectures on Machine Learning (Lecture 2) | |
| # Ameer Ayoub <ameer.ayoub@gmail.com> 2/16/2012 | |
| import math | |
| import matplotlib.pyplot as plt | |
def pd_sqr_err0(theta0, theta1, x, y):
    """Return the partial derivative of the squared-error cost w.r.t. theta0.

    For the hypothesis h(x) = theta0 + theta1 * x this is the mean residual,
    (1/m) * sum(h(x_i) - y_i), taken over the m training samples.

    Args:
        theta0: Intercept of the current hypothesis.
        theta1: Slope of the current hypothesis.
        x: Sequence of input values.
        y: Sequence of target values, the same length as ``x``.

    Returns:
        The mean residual, or 0 when there are no samples.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length.
    """
    m = len(x)
    if len(y) != m:
        # A silent truncation here would yield a plausible but wrong gradient.
        raise ValueError("x and y must have the same length")
    if m == 0:
        # No samples: nothing to average (and 1.0 / m would be undefined).
        return 0
    scale = 1.0 / m
    return sum(scale * ((theta0 + theta1 * xi) - yi) for xi, yi in zip(x, y))
def pd_sqr_err1(theta0, theta1, x, y):
    """Return the partial derivative of the squared-error cost w.r.t. theta1.

    For the hypothesis h(x) = theta0 + theta1 * x this is
    (1/m) * sum((h(x_i) - y_i) * x_i), taken over the m training samples.

    Args:
        theta0: Intercept of the current hypothesis.
        theta1: Slope of the current hypothesis.
        x: Sequence of input values.
        y: Sequence of target values, the same length as ``x``.

    Returns:
        The mean of the residuals weighted by their inputs, or 0 when there
        are no samples.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length.
    """
    m = len(x)
    if len(y) != m:
        # A silent truncation here would yield a plausible but wrong gradient.
        raise ValueError("x and y must have the same length")
    if m == 0:
        # No samples: nothing to average (and 1.0 / m would be undefined).
        return 0
    scale = 1.0 / m
    return sum(
        scale * ((theta0 + theta1 * xi) - yi) * xi for xi, yi in zip(x, y)
    )
def ubgd(x, y, theta0=0, theta1=0, eps=10**(-6), alpha=1, max_iter=1000000):
    """Fit h(x) = theta0 + theta1 * x by univariate batch gradient descent.

    Each iteration evaluates both partial derivatives over the whole data set
    and updates theta0 and theta1 simultaneously. Iteration stops once the
    largest absolute parameter change is no greater than ``eps``, once
    ``max_iter`` updates have been performed, or once the change becomes NaN.
    A one-line status message is printed to stdout before returning.

    Args:
        x: Sequence of input values.
        y: Sequence of target values, parallel to ``x``.
        theta0: Initial intercept.
        theta1: Initial slope.
        eps: Convergence tolerance on the largest per-step parameter change.
        alpha: Learning rate (step size).
        max_iter: Upper bound on the number of update steps.

    Returns:
        A ``(theta0, theta1)`` tuple holding the final parameters.
    """
    iterations = 0
    # Start above any tolerance so the first update always runs; a start
    # value derived from eps would skip the loop entirely when eps <= 0.
    max_diff = float("inf")
    # Repeat until the parameters converge or the iteration budget runs out.
    # A NaN max_diff also ends the loop because NaN > eps is False.
    while max_diff > eps and iterations < max_iter:
        # Both gradients are taken at the current point before either
        # parameter changes, so the update is simultaneous.
        grad0 = pd_sqr_err0(theta0, theta1, x, y)
        grad1 = pd_sqr_err1(theta0, theta1, x, y)
        new_theta0 = theta0 - (alpha * grad0)
        new_theta1 = theta1 - (alpha * grad1)
        max_diff = max(math.fabs(new_theta0 - theta0),
                       math.fabs(new_theta1 - theta1))
        theta0, theta1 = new_theta0, new_theta1
        iterations += 1
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    if math.isnan(max_diff):
        print("divergence, try lowering alpha.")
    if iterations == max_iter:
        print("max iterations reached.")
    else:
        print("{0} iterations completed".format(iterations))
    return theta0, theta1
if __name__ == "__main__":
    # House pricing example as in the lecture.
    x = [2104, 1416, 1534, 852]
    y = [460, 232, 315, 178]
    h = ubgd(x, y, alpha=10**-9)
    print("h(x) = {0} + {1}x".format(*h))
    # Training samples as black dots.
    plt.plot(x, y, 'ko')
    # Evaluate the fitted line from 0 to 10% past the largest input. The loop
    # variable is deliberately not named x: on Python 2 a list-comprehension
    # variable leaks and would overwrite the data list.
    line_x = list(range(int(max(x) * 1.1)))
    line_y = [h[0] + (h[1] * xi) for xi in line_x]
    plt.plot(line_x, line_y, '-')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment