Created
August 5, 2020 10:13
-
-
Save netsatsawat/07a5c340787128cee02e966680f1846b to your computer and use it in GitHub Desktop.
Snippet of mini-batch training using SGDRegressor.partial_fit in scikit-learn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _get_chunk(X, y, chunkrows):
    '''
    Index X and y with the same selector and return both results.

    chunkrows is applied unchanged to each container, so the returned
    (X_chunk, y_chunk) pair stays aligned.
    '''
    return X[chunkrows], y[chunkrows]
def _iter_minibatch(X, y, chunk_size):
    '''
    Construct minibatch generator

    Walks over the first axis of X in consecutive windows of chunk_size
    rows and yields one (X_chunk, y_chunk) pair per window, obtained via
    _get_chunk. The number of rows is read from X.shape[0].

    The last window is clamped to the number of rows, so when the row
    count is not a multiple of chunk_size the final chunk is simply
    shorter instead of indexing past the end of the data.

    Raises ValueError (on first iteration) if chunk_size is not positive,
    because the window could then never advance.
    '''
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    total_rows = X.shape[0]
    for start in range(0, total_rows, chunk_size):
        # Clamp the upper bound: an unclamped range would request row
        # indices >= total_rows for the trailing partial chunk.
        stop = min(start + chunk_size, total_rows)
        yield _get_chunk(X, y, range(start, stop))
# Mini-batch training of an SGDRegressor: the data is streamed in chunks of
# 100 rows and every chunk is fed to partial_fit.
SGD_model = SGDRegressor(fit_intercept=True, random_state=SEED, eta0=learning_rate,
                         learning_rate='constant', max_iter=n_epochs)
# partial_fit performs a single pass over the chunk it receives and does not
# honour max_iter, so the passes over the full data set are driven explicitly.
for _epoch in range(n_epochs):
    # A generator is exhausted after one pass; build a fresh one per epoch.
    mini_batch_generator = _iter_minibatch(X=X, y=y, chunk_size=100)
    for X_chunk, y_chunk in mini_batch_generator:
        SGD_model.partial_fit(X_chunk, y_chunk)
# Evaluate on the same data that was used for training.
y_pred = SGD_model.predict(X)
print(SGD_model)
print(f'Intercept: {SGD_model.intercept_}, weights: {SGD_model.coef_}')
_ = print_regress_metric(y, y_pred)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment