
@phil8192
Created April 3, 2020 20:27
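import numpy as np
from sklearn import linear_model

def train_model(intercept_init, coef_init, X, y, epochs, lr, batch_size=None, randomise=True):
    # A missing or non-positive batch size means one batch per epoch containing every sample.
    if batch_size is None or batch_size <= 0:
        batch_size = X.shape[0]
    classes = np.unique(y)
    # Logistic regression fitted by SGD with a constant step size eta0=lr.
    # scikit-learn >= 1.1 calls this loss 'log_loss'; earlier releases used 'log'.
    model = linear_model.SGDClassifier(loss='log_loss', learning_rate='constant', eta0=lr, verbose=0)
    # set_weights and batch_train are helpers that are not included in this snippet.
    set_weights(intercept_init, coef_init, classes, model)
    batch_train(model, X, y, classes, epochs, batch_size, randomise)
    return model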