@erap129
Created September 27, 2021 14:26
NASA RUL project - LSTM training
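# RULPredictor, model, data_module, n_epochs and feature_columns are
# defined elsewhere in the project and are not part of this snippet.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

# Keep only the single checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='best-checkpoint',
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min'
)

# TensorBoard logs go under lightning_logs/RUL/.
logger = TensorBoardLogger('lightning_logs', name='RUL')

# The original used gpus=1 and progress_bar_refresh_rate=30, which newer
# Lightning releases removed; accelerator/devices and TQDMProgressBar replace them.
trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, TQDMProgressBar(refresh_rate=30)],
    max_epochs=n_epochs,
    accelerator='gpu',
    devices=1
)
trainer.fit(model, data_module)

# Reload the best weights; n_features is passed through to RULPredictor.__init__.
trained_model = RULPredictor.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path,
    n_features=len(feature_columns)
)
# Disable gradients and switch to eval mode for inference.
trained_model.freeze()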
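With monitor='val_loss', mode='min' and save_top_k=1, only the epoch with the lowest validation loss is kept in checkpoints/, and that is the file best_model_path points to when the model is reloaded. The sketch below is a hypothetical pure-Python simulation of that selection rule, not Lightning's own implementation; the function name track_best_checkpoint, its messages, and the loss values are made up for illustration.

```python
from typing import List, Optional, Tuple


def track_best_checkpoint(
    val_losses: List[float], mode: str = "min"
) -> Tuple[Optional[int], Optional[float], List[str]]:
    """Hypothetical simulation of save_top_k=1: keep only the best epoch.

    Not Lightning's code; it only illustrates the selection rule.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best_epoch: Optional[int] = None
    best_value: Optional[float] = None
    messages: List[str] = []
    for epoch, value in enumerate(val_losses):
        # Strict comparison: a tie does not replace the current best.
        improved = best_value is None or (
            value < best_value if mode == "min" else value > best_value
        )
        if improved:
            # With save_top_k=1 the previous best file is discarded.
            best_epoch, best_value = epoch, value
            messages.append(f"epoch {epoch}: new best {value:.4f}, checkpoint replaced")
        else:
            messages.append(f"epoch {epoch}: {value:.4f} not better, checkpoint kept")
    return best_epoch, best_value, messages


if __name__ == "__main__":
    # Made-up validation losses, purely for illustration.
    losses = [0.91, 0.64, 0.70, 0.52, 0.58]
    epoch, value, log = track_best_checkpoint(losses)
    for line in log:
        print(line)
    print(f"best checkpoint: epoch {epoch} (val_loss={value})")  # epoch 3, 0.52
```

Switching mode to 'max' would suit a metric where higher is better (for example an accuracy score); for a loss such as val_loss, 'min' is the right choice.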