Skip to content

Instantly share code, notes, and snippets.

@Erlemar
Created May 3, 2020 09:22
Show Gist options
  • Save Erlemar/43bf08fb7114aca99b46fb04bab4b41f to your computer and use it in GitHub Desktop.
Save Erlemar/43bf08fb7114aca99b46fb04bab4b41f to your computer and use it in GitHub Desktop.
Run experiment with hydra, pytorch-lightning, comet
import os
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning.loggers import CometLogger, TensorBoardLogger
from src.lightning_classes.lightning_nbeats import LitM5NBeats
from src.utils.utils import set_seed
def run(cfg: DictConfig):
    """Build the model, callbacks, loggers and trainer from ``cfg`` and fit.

    Args:
        cfg: Experiment configuration. The following keys are read here:
            ``training.seed``, ``training.epochs``,
            ``callbacks.early_stopping.params``,
            ``callbacks.model_checkpoint.params``,
            ``general.save_dir``, ``general.workspace``,
            ``general.project_name`` and ``private.comet_api``.
            The whole config is also handed to ``LitM5NBeats``.
    """
    set_seed(cfg.training.seed)
    model = LitM5NBeats(cfg=cfg)

    # Callback keyword arguments come straight from the config.
    early_stopping = pl.callbacks.EarlyStopping(
        **cfg.callbacks.early_stopping.params
    )
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        **cfg.callbacks.model_checkpoint.params
    )

    # Name the Comet experiment after the last component of the current
    # working directory. os.path.basename honours the platform's path
    # separators; splitting on a backslash only worked on Windows and
    # produced the full absolute path on POSIX systems.
    # NOTE(review): presumably the cwd is Hydra's per-run output
    # directory at this point -- confirm against the Hydra setup.
    experiment_name = os.path.basename(os.getcwd())

    tb_logger = TensorBoardLogger(cfg.general.save_dir)
    comet_logger = CometLogger(
        save_dir=cfg.general.save_dir,
        workspace=cfg.general.workspace,
        project_name=cfg.general.project_name,
        api_key=cfg.private.comet_api,
        experiment_name=experiment_name,
    )

    # Both loggers receive the training metrics. The GPU count and the
    # 'dp' backend are fixed here rather than read from the config.
    trainer = pl.Trainer(
        gpus=2,
        logger=[tb_logger, comet_logger],
        distributed_backend='dp',
        max_epochs=cfg.training.epochs,
        early_stop_callback=early_stopping,
        checkpoint_callback=model_checkpoint,
        nb_sanity_val_steps=0,
    )
    trainer.fit(model)
@hydra.main(config_path="conf/config.yaml")
def run_model(cfg: DictConfig) -> None:
    """Print the configuration received from Hydra, then call ``run``.

    Args:
        cfg: Configuration object supplied by the ``hydra.main`` decorator,
            loaded from ``conf/config.yaml``.
    """
    config_dump = cfg.pretty()
    print(config_dump)
    run(cfg)
# Start the Hydra-decorated entry point only when this file is executed
# as a script, not when it is imported.
if __name__ == "__main__":
    run_model()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment