# Optuna LSTM
# GitHub Gist by @rohithteja (gist 821935e7c3c690bb424ab8e1edc59dec), last active June 21, 2022.
import optuna
from optuna.trial import TrialState
from sklearn.metrics import accuracy_score
def objective(trial):
    """Train one LSTM with sampled hyperparameters and return its validation score.

    Samples an optimizer name, an epoch count (5, 10 or 15) and a batch size
    (8, 24 or 40). It then trains a model through the module-level ``lstm``
    helper and evaluates it on the module-level ``X_val`` / ``y_val`` arrays.

    The trained weights are attached to the trial as the
    ``"best_model_weights"`` user attribute, so ``callback`` can copy the
    winning trial's weights onto the study.

    Args:
        trial: The ``optuna.trial.Trial`` being evaluated.

    Returns:
        Element ``[1]`` of ``model.evaluate(X_val, y_val)``. This is presumably
        validation accuracy, assuming ``lstm`` compiles the model with a single
        accuracy metric (TODO confirm). The study maximizes this value.
    """
    optimizer_name = trial.suggest_categorical(
        "optimizer", ["adam", "SGD", "RMSprop", "Adadelta"]
    )
    epochs = trial.suggest_int("epochs", 5, 15, step=5)
    batchsize = trial.suggest_int("batchsize", 8, 40, step=16)

    _history, model = lstm(optimizer_name, epochs, batchsize)
    val_acc = model.evaluate(X_val, y_val)[1]
    weights = model.get_weights()

    # No pruning check here. ``trial.should_prune()`` only acts on values passed
    # to ``trial.report(value, step)``, and none are reported, so such a check
    # could never fire. Pruning only pays off if it happens *during* training:
    # report per-epoch validation scores from inside ``lstm`` (e.g. via a Keras
    # callback) and raise ``optuna.TrialPruned`` there.

    # NOTE(review): Optuna documents user-attr values as JSON-serializable. A
    # list of weight arrays only works with the default in-memory storage, not
    # with RDB/journal storage.
    trial.set_user_attr(key="best_model_weights", value=weights)
    return val_acc
def callback(study, trial):
    """Copy the best trial's model weights onto the study after each trial.

    Registered through ``study.optimize(..., callbacks=[callback])``. When the
    trial that just finished is the study's current best, its
    ``"best_model_weights"`` user attribute (set by ``objective``) is stored
    on the study under the same key.

    Args:
        study: The ``optuna.study.Study`` being optimized.
        trial: The ``FrozenTrial`` that just finished.
    """
    # Only completed trials can be the best trial. Until at least one trial has
    # completed, ``study.best_trial`` raises ValueError, and an exception raised
    # in a callback aborts ``study.optimize``. This can happen, for example,
    # with pruned trials, a NaN objective value, or failures tolerated via
    # ``catch=``.
    if trial.state != TrialState.COMPLETE:
        return
    if study.best_trial.number == trial.number:
        study.set_user_attr(
            key="best_model_weights",
            value=trial.user_attrs["best_model_weights"],
        )
# Run the search: maximize the value returned by ``objective`` over 20 trials.
# ``callback`` keeps the best trial's model weights on the study.
study = optuna.create_study(direction="maximize")
study.optimize(
    objective,
    n_trials=20,
    timeout=None,  # no wall-clock limit; stop after n_trials
    callbacks=[callback],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment