Created
August 5, 2020 09:44
-
-
Save netsatsawat/e6cca4c9233b4dc68db3ddd63d2b0e62 to your computer and use it in GitHub Desktop.
Snippet of stochastic gradient descent (SGD) for linear regression
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _iter(X, y, batch_size: int = 1):
    """Yield shuffled mini-batches of ``(X, y)``.

    The row order is drawn from NumPy's global RNG, so calling
    ``np.random.seed(...)`` beforehand makes the batch order reproducible.

    Args:
        X: array of observations, indexed by row along axis 0.
        y: array of labels. ``y.take`` is called without an axis, so it
            indexes the flattened array. ``y`` is presumably 1-D (confirm
            against callers).
        batch_size: rows per batch. The final batch may be smaller.

    Yields:
        ``(batch_id, X_batch, y_batch)``, with ``batch_id`` counting from 0.

    Raises:
        ValueError: if ``batch_size`` is less than 1. Because this is a
            generator, the error is raised on the first ``next()``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    n_observations = X.shape[0]
    # Use np.random rather than the stdlib ``random`` module so that
    # np.random.seed (as used by the training loop) controls the order.
    order = np.random.permutation(n_observations)
    for batch_id, start in enumerate(range(0, n_observations, batch_size)):
        # Slicing clamps at the end, so the last batch is simply shorter.
        pos = order[start:start + batch_size]
        yield batch_id, X.take(pos, axis=0), y.take(pos)
def get_pred(X, theta):
    """Return linear-model predictions: the dot product of ``X`` with ``theta``."""
    yhat = np.dot(X, theta)
    return yhat
def grad_loss(X, y, theta):
    """Return the gradient of the squared-error loss with respect to ``theta``.

    Computes ``X^T (X @ theta - y) / n`` with ``n = len(X)``. This is the
    gradient of half the mean squared error. The exact MSE gradient has an
    extra factor of 2, and leaving it out only rescales the effective
    learning rate.

    Reference:
    https://math.stackexchange.com/questions/1962877/compute-the-gradient-of-mean-square-error
    """
    n_rows = len(X)
    residual = get_pred(X, theta) - y
    return np.dot(np.transpose(X), residual) / n_rows
def _sgd_regressor(X, y, learning_rate, n_epochs, batch_size=1,
                   log_every=100, seed=None):
    """Fit linear-regression weights with mini-batch stochastic gradient descent.

    Seeds NumPy's global RNG and draws the initial ``theta`` uniformly from
    [0, 1). Each epoch then steps ``theta`` against ``grad_loss`` for every
    mini-batch produced by ``_iter`` and records the full-data MSE.

    Args:
        X: training observations. ``len(X[0])`` gives the number of weights.
        y: training targets.
        learning_rate: step size applied to each mini-batch gradient.
        n_epochs: index of the last epoch. Passes run for epochs
            ``0..n_epochs`` inclusive, i.e. ``n_epochs + 1`` passes.
        batch_size: rows per mini-batch (1 = plain SGD).
        log_every: print the MSE every ``log_every`` epochs. ``0`` or
            ``None`` disables printing.
        seed: seed for ``np.random.seed``. Defaults to the module-level
            ``SEED``.

    Returns:
        ``(theta, theta_log, mse_log)``: the final weights, the weights after
        each epoch, and the full-data MSE after each epoch.
    """
    mse_log = []
    theta_log = []
    np.random.seed(SEED if seed is None else seed)
    theta = np.random.rand(len(X[0]))
    # NOTE: n_epochs + 1 passes is kept as-is so existing results and log
    # lengths are unchanged.
    for epoch in range(n_epochs + 1):
        for _, batch_X, batch_y in _iter(X, y, batch_size):
            theta = theta - learning_rate * grad_loss(batch_X, batch_y, theta)
        epoch_mse = mean_squared_error(y, get_pred(X, theta))
        mse_log.append(epoch_mse)
        # theta is rebound (never mutated in place), so each stored entry is
        # an independent snapshot.
        theta_log.append(theta)
        if log_every and epoch % log_every == 0:
            print(f'Epoch: {epoch} | MSE: {epoch_mse}')
    return theta, theta_log, mse_log
# Fit with single-sample SGD. Keep the final weights and the per-epoch MSE
# history; the per-epoch weight history is discarded.
theta, _, mse_ = _sgd_regressor(
    X_,
    y,
    learning_rate=learning_rate,
    n_epochs=n_epochs,
    batch_size=1,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment