An example script that initializes a Trainable for Ray Tune and starts hyperparameter tuning
import tensorflow as tf
from ray import tune

# Generator, FCN_model, create_callbacks, logger, args, num_samples,
# search_alg and scheduler are assumed to be defined elsewhere in the project.

class Trainable:
    def __init__(self, train_dir, val_dir, snapshot_dir, final_run=False):
        # Initialize the state variables for the run
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.final_run = final_run
        self.snapshot_dir = snapshot_dir

    def train(self, config, reporter=None):
        # If you get an out-of-memory error, try reducing the maximum batch size
        train_generator = Generator(self.train_dir, config['batch_size'])
        val_generator = Generator(self.val_dir, config['batch_size'])

        # Create the FCN model
        model = FCN_model(config, len_classes=len(train_generator.classes))

        # Compile the model with its loss and metrics
        model.compile(optimizer=tf.keras.optimizers.Nadam(learning_rate=config['lr']),
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

        # Create callbacks to be used during model training; during tuning these
        # are expected to report metrics back to Tune (see the sketch below)
        callbacks = create_callbacks(self.final_run, self.snapshot_dir)

        logger.info("Starting model training")
        # Start model training
        history = model.fit(train_generator,
                            steps_per_epoch=len(train_generator),
                            epochs=100,
                            callbacks=callbacks,
                            validation_data=val_generator,
                            validation_steps=len(val_generator))
        return history
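
# A minimal, hypothetical sketch of the kind of Keras callback that
# create_callbacks might wire in when final_run is False, so that each
# epoch's validation metrics reach Ray Tune's search algorithm and
# scheduler. The class name and reported keys are assumptions, not part
# of the original gist.
class TuneReportCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Forward this epoch's validation metrics to Ray Tune
        tune.report(val_loss=logs.get('val_loss'),
                    val_accuracy=logs.get('val_accuracy'))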
logger.info("Initializing ray Trainable") | |
# Initialize Trainable for hyperparameter tuning | |
trainer = Trainable(args.train_dir, args.val_dir, args.snapshot_dir, final_run=False) | |
logger.info("Starting hyperparameter tuning") | |
analysis = tune.run(trainer.train, | |
verbose=1, | |
num_samples=num_samples, | |
search_alg=search_alg, | |
scheduler=scheduler, | |
raise_on_failed_trial=False, | |
resources_per_trial={"cpu": 16, "gpu": 2} | |
) | |
best_config = analysis.get_best_config(metric="val_loss", mode='min') | |
logger.info(f'Best config: {best_config}') |
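
The script assumes num_samples, search_alg and scheduler are built beforehand. Below is a minimal sketch of how they might look with HyperOpt and async HyperBand; the search-space keys, ranges and sample count are assumptions chosen to match the config keys used in Trainable.train, not values from the original gist.

from hyperopt import hp
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch

# Number of hyperparameter configurations to try (assumed value)
num_samples = 20

# Search space keyed to match config['batch_size'] and config['lr'] above
space = {
    'batch_size': hp.choice('batch_size', [8, 16, 32]),
    'lr': hp.loguniform('lr', -9, -3),  # roughly 1e-4 to 5e-2
}
search_alg = HyperOptSearch(space, metric='val_loss', mode='min')

# Stop poorly performing trials early based on the reported val_loss
scheduler = AsyncHyperBandScheduler(metric='val_loss', mode='min')

Once tuning finishes, a natural follow-up is to pass best_config back into Trainable.train with final_run=True to retrain the model on the winning hyperparameters.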