Created
December 7, 2018 16:26
-
-
Save koshian2/497cf82479c6f9d1d92d19d400355705 to your computer and use it in GitHub Desktop.
Optuna Keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import tensorflow as tf | |
| from tensorflow.keras.applications import InceptionV3, VGG16, MobileNet | |
| from tensorflow.keras.layers import GlobalAveragePooling2D, Dense | |
| from tensorflow.keras.models import Model | |
| from tensorflow.keras.callbacks import History, Callback | |
| import tensorflow.keras.backend as K | |
| from tensorflow.contrib.tpu.python.tpu import keras_support | |
| from keras.utils import to_categorical | |
| from keras.datasets import cifar10 | |
| import numpy as np | |
| import os, pickle, glob, zipfile | |
| import optuna | |
def create_network(network, num_classes=10):
    """Build an ImageNet-pretrained backbone topped with a new softmax head.

    Args:
        network: Backbone name; one of ``"inception"``, ``"vgg"``, ``"mobilenet"``.
        num_classes: Number of units in the final softmax layer (10 for CIFAR-10).

    Returns:
        An uncompiled ``Model`` that takes ``(128, 128, 3)`` inputs and outputs
        ``num_classes`` class probabilities.

    Raises:
        ValueError: If ``network`` is not one of the supported backbone names.
    """
    supported = ("inception", "vgg", "mobilenet")
    # Validate before touching any backbone class (an ``assert`` would vanish under -O).
    if network not in supported:
        raise ValueError(
            "network must be one of {}, got {!r}".format(supported, network)
        )
    # Minimum input resolution: InceptionV3 >= 75, VGG16 >= 32, MobileNet >= 128,
    # so every backbone is built for 128x128 inputs.
    backbone_classes = {
        "inception": InceptionV3,
        "vgg": VGG16,
        "mobilenet": MobileNet,
    }
    net = backbone_classes[network](
        include_top=False, weights="imagenet", input_shape=(128, 128, 3)
    )
    # Freeze every layer except the last 5 so only the top of the backbone is fine-tuned.
    for layer in net.layers[:-5]:
        layer.trainable = False
    x = GlobalAveragePooling2D()(net.layers[-1].output)
    x = Dense(num_classes, activation="softmax")(x)
    return Model(net.inputs, x)
class OptunaCallback(Callback):
    """Keras callback that reports validation error to an Optuna trial and prunes it.

    At the end of every epoch the validation error ``1 - logs["val_acc"]`` is
    reported to ``trial`` at step ``epoch``. If the trial then says it should be
    pruned, ``optuna.structs.TrialPruned`` is raised to abort training.
    """

    def __init__(self, trial):
        # Initialise the Keras Callback base class first; skipping it leaves any
        # attributes its __init__ sets undefined on this instance.
        super().__init__()
        self.trial = trial

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Requires the model to be compiled with the "acc" metric (key "val_acc").
        current_val_error = 1.0 - logs["val_acc"]
        self.trial.report(current_val_error, step=epoch)
        # Pruning decision.
        # NOTE(review): should_prune(step) matches the Optuna release this script
        # targets; newer releases take no argument -- confirm for the installed version.
        if self.trial.should_prune(epoch):
            raise optuna.structs.TrialPruned()
def generator(X, y, batch_size, scale=4, num_classes=10):
    """Yield shuffled, upsampled, normalised mini-batches forever.

    Every pass over the data uses a fresh random permutation. The trailing
    ``len(X) % batch_size`` samples of each permutation are skipped, so every
    batch contains exactly ``batch_size`` samples.

    Args:
        X: Image array indexed by sample on axis 0, with height/width on axes 1
            and 2 (assumed uint8 pixels in 0..255, as in CIFAR-10 -- confirm).
        y: Integer class labels, shape ``(N,)`` or ``(N, 1)``.
        batch_size: Samples per batch; must be in ``[1, len(X)]``.
        scale: Nearest-neighbour upsampling factor for height and width
            (4 turns 32x32 CIFAR images into 128x128).
        num_classes: Width of the one-hot label vectors.

    Yields:
        ``(X_batch, y_batch)``. ``X_batch`` is float32 with height and width
        multiplied by ``scale`` and values divided by 255. ``y_batch`` is
        float32 one-hot of shape ``(batch_size, num_classes)``.

    Raises:
        ValueError: If ``batch_size`` is not positive or exceeds ``len(X)``.
            Otherwise no full batch could be formed and the generator would
            loop forever without yielding.
    """
    n_samples = X.shape[0]
    if batch_size <= 0 or batch_size > n_samples:
        raise ValueError(
            "batch_size must be in [1, {}], got {}".format(n_samples, batch_size)
        )
    steps = n_samples // batch_size
    labels = np.asarray(y).reshape(-1)
    one_hot = np.eye(num_classes, dtype=np.float32)
    while True:
        indices = np.random.permutation(n_samples)
        for step in range(steps):
            batch_idx = indices[step * batch_size:(step + 1) * batch_size]
            # Upsample while still in the compact source dtype, then convert.
            images = X[batch_idx].repeat(scale, axis=1).repeat(scale, axis=2)
            # Go straight to float32 and divide in place, avoiding the
            # float64 temporary that ``images / 255.0`` would allocate.
            X_batch = images.astype(np.float32)
            X_batch /= 255.0
            y_batch = one_hot[labels[batch_idx]]
            yield X_batch, y_batch
def train(network, optimizer, learning_rate, trial, batch_size=1024, epochs=50):
    """Fine-tune ``network`` on CIFAR-10 on a Colab TPU and return its history.

    Args:
        network: Backbone name forwarded to ``create_network``.
        optimizer: One of ``"sgd"``, ``"momentum"``, ``"rmsprop"``, ``"adam"``.
        learning_rate: Learning rate for the chosen optimizer.
        trial: Optuna trial that ``OptunaCallback`` reports to (and may prune).
        batch_size: Samples per training and validation step.
        epochs: Maximum number of epochs; the trial may be pruned earlier.

    Returns:
        The ``History.history`` dict of per-epoch metrics (including ``"val_acc"``).

    Raises:
        ValueError: If ``optimizer`` is not a supported name.
        KeyError: If the ``COLAB_TPU_ADDR`` environment variable is not set.
    """
    # Lambdas defer optimizer construction until after the name is validated.
    optimizer_factories = {
        "sgd": lambda lr: tf.train.GradientDescentOptimizer(lr),
        "momentum": lambda lr: tf.train.MomentumOptimizer(lr, 0.9),
        "rmsprop": lambda lr: tf.train.RMSPropOptimizer(lr),
        "adam": lambda lr: tf.train.AdamOptimizer(lr),
    }
    if optimizer not in optimizer_factories:
        raise ValueError(
            "optimizer must be one of {}, got {!r}".format(
                sorted(optimizer_factories), optimizer
            )
        )

    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    train_gen = generator(X_train, y_train, batch_size)
    test_gen = generator(X_test, y_test, batch_size)

    model = create_network(network)
    model.compile(
        optimizer_factories[optimizer](learning_rate),
        "categorical_crossentropy",
        ["acc"],
    )

    # Convert to a TPU model using the TF 1.x contrib API and the Colab TPU address.
    tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    # Creates a per-network directory.
    # NOTE(review): nothing in this function writes to it -- confirm it is still needed.
    if not os.path.exists(network):
        os.mkdir(network)

    hist = History()
    truncate = OptunaCallback(trial)
    model.fit_generator(
        train_gen,
        X_train.shape[0] // batch_size,
        callbacks=[hist, truncate],
        validation_data=test_gen,
        validation_steps=X_test.shape[0] // batch_size,
        epochs=epochs,
    )
    return hist.history
def optuna_finding(network, n_trials=50, csv_path="cifar.csv"):
    """Search the optimizer and learning rate for ``network`` with Optuna.

    Each trial fine-tunes the network via ``train`` and is scored by its best
    validation error, ``1 - max(val_acc)`` (lower is better). The study is
    created with ``optuna.create_study()``'s default settings. After the search,
    the best parameters, value and trial are printed and the trials table is
    written to ``csv_path``.

    Args:
        network: Backbone name forwarded to ``train``.
        n_trials: Number of Optuna trials to run.
        csv_path: Destination of ``study.trials_dataframe()`` as CSV.
    """
    def objective(trial):
        # Hyperparameters under search: the optimizer and its learning rate.
        optimizer = trial.suggest_categorical(
            "optimizer", ["sgd", "momentum", "rmsprop", "adam"]
        )
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-7, 1e0)
        # Clear the Keras session before building this trial's model.
        K.clear_session()
        hist = train(network, optimizer, learning_rate, trial)
        return 1.0 - np.max(hist["val_acc"])

    study = optuna.create_study()
    study.optimize(objective, n_trials=n_trials)
    print(study.best_params)
    print(study.best_value)
    print(study.best_trial)
    trial_df = study.trials_dataframe()
    trial_df.to_csv(csv_path)
if __name__ == "__main__":
    # Script entry point: tune the optimizer/learning rate for the MobileNet backbone.
    target_network = "mobilenet"
    optuna_finding(target_network)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment