# Created June 16, 2019 13:38
# Gist BrambleXu/4bc854fdf2a45f9eb1fbbae3aad5b291 (can be saved locally or opened in GitHub Desktop).
# GitHub notice: this file may contain hidden or bidirectional Unicode text that could be
# interpreted or compiled differently than it appears; review it in an editor that reveals
# hidden Unicode characters (see GitHub's note on bidirectional Unicode characters).
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# Pin NumPy's global random generator so the synthetic noise and the random
# initial coefficients are identical on every run.
np.random.seed(seed=0)
| # the real model line | |
| def g(x): | |
| return 0.1 * (x + x**2 + x**3) | |
# Synthetic training set: 8 evenly spaced inputs on [-2, 2]; targets are the
# true curve g plus Gaussian noise with standard deviation 0.05.
train_x = np.linspace(start=-2, stop=2, num=8)
train_y = g(train_x) + 0.05 * np.random.randn(train_x.size)
# Optional: plot the raw training data against the true curve g.
# x = np.linspace(-2, 2, 100)
# plt.plot(train_x, train_y, 'o')
# plt.plot(x, g(x), linestyle='dashed')
# plt.ylim(-1, 2)
# plt.show()
# Z-score statistics of the training inputs (population std, NumPy's default
# ddof=0); every later input is rescaled with these same two numbers.
mu, std = train_x.mean(), train_x.std()
def standardizer(x):
    """Rescale ``x`` onto the training inputs' z-score scale: (x - mu) / std.

    ``mu`` and ``std`` are the module-level statistics of ``train_x``, read at
    call time, so new inputs are shifted and scaled exactly like the training set.
    """
    shifted = x - mu
    return shifted / std


std_x = standardizer(train_x)
def to_matrix(x, degree=10):
    """Build the polynomial design matrix for ``x``.

    Row ``i`` holds ``x[i]**0, x[i]**1, ..., x[i]**degree``, so a 1-D input of
    length n yields a float matrix of shape ``(n, degree + 1)``.

    Parameters
    ----------
    x : array_like
        1-D array of (typically standardized) inputs.
    degree : int, optional
        Highest power to include. Defaults to 10, the degree this script fits.

    Returns
    -------
    numpy.ndarray
        Design matrix with a leading column of ones (the bias term).

    Raises
    ------
    ValueError
        If ``degree`` is negative.
    """
    if degree < 0:
        raise ValueError("degree must be non-negative")
    # Cast to float up front: integer inputs raised to high powers in int64
    # arithmetic would silently overflow (e.g. 100 ** 10 > 2 ** 63).
    x = np.asarray(x, dtype=float)
    columns = [np.ones(x.size)] + [x ** k for k in range(1, degree + 1)]
    # vstack stacks one power per row; transpose so each sample is a row.
    return np.vstack(columns).T
# Design matrix of the standardized training inputs, plus a random starting
# coefficient vector with one entry per design-matrix column.
mat_x = to_matrix(std_x)
theta = np.random.randn(mat_x.shape[-1])
def f(x, params=None):
    """Predict targets for design matrix ``x`` as ``x @ params``.

    Parameters
    ----------
    x : numpy.ndarray
        Design matrix (e.g. the output of ``to_matrix``); its last axis must
        match the length of the coefficient vector.
    params : numpy.ndarray, optional
        Coefficient vector to predict with. When omitted, the module-level
        ``theta`` is read at call time, exactly as before, so existing callers
        are unaffected. Passing it explicitly lets saved fits (e.g. ``theta1``)
        be evaluated without reassigning the global.

    Returns
    -------
    numpy.ndarray
        Predictions, one per row of ``x``.
    """
    if params is None:
        params = theta
    return np.dot(x, params)
def E(x, y):
    """Return the sum-of-squares cost 0.5 * sum((y - f(x)) ** 2).

    ``f`` predicts with the module-level ``theta``, so this is the cost of the
    current parameters on design matrix ``x`` and targets ``y``.
    """
    residual = y - f(x)
    return 0.5 * np.sum(residual ** 2)
# Gradient-descent bookkeeping.
ETA = 1e-4  # learning rate (step size)
diff = 1  # cost decrease between epochs; starts above the tolerance so the loop runs
error = E(mat_x, train_y)  # cost of the random initial theta
import warnings

######## training without regularization ########
# Batch gradient descent on E. With 8 training points and the degree-10 design
# matrix, mat_x has shape (8, 11) and f(mat_x) - train_y has shape (8,), so
# (f(mat_x) - train_y) @ mat_x is the (11,) gradient of E with respect to theta.
while diff > 1e-6:
    theta = theta - ETA * np.dot(f(mat_x) - train_y, mat_x)
    current_error = E(mat_x, train_y)
    diff = error - current_error
    error = current_error

# The loop also stops when the cost *increases* (diff < 0), which means ETA is
# too large for this design matrix. Surface that instead of silently keeping
# the diverging parameters.
if diff < 0:
    warnings.warn(
        f"gradient descent stopped because the error increased by {-diff:.3g}; "
        "consider a smaller learning rate ETA"
    )
# Snapshot the unregularized fit. copy() makes theta1 an independent array, so
# it keeps these values even if theta is later updated in place (theta -= ...).
theta1 = theta.copy()
########## plot the line without regularization ##########
# Evaluate the fitted polynomial on a dense grid. Both the grid and the training
# points are shown in standardized-x coordinates.
z = standardizer(np.linspace(-2, 2, 100))
theta = theta1  # f() reads the global theta, so point it at the saved fit

plt.ylim(-1, 2)
plt.plot(std_x, train_y, 'o')
plt.plot(z, f(to_matrix(z)), linestyle='dashed')
plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.