Created May 3, 2020 09:22
Save Erlemar/43bf08fb7114aca99b46fb04bab4b41f to your computer and use it in GitHub Desktop.
Run an experiment with Hydra, PyTorch Lightning, and Comet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import hydra | |
import pytorch_lightning as pl | |
from omegaconf import DictConfig | |
from pytorch_lightning.loggers import CometLogger, TensorBoardLogger | |
from src.lightning_classes.lightning_nbeats import LitM5NBeats | |
from src.utils.utils import set_seed | |
def run(cfg: DictConfig):
    """Train ``LitM5NBeats`` with the settings in ``cfg``.

    Seeds the run, builds the model, creates early-stopping and checkpoint
    callbacks, attaches TensorBoard and Comet loggers, and fits a PyTorch
    Lightning trainer on 2 GPUs with the 'dp' distributed backend.

    Args:
        cfg: Hydra/OmegaConf config. Keys read directly here:
            ``training.seed``, ``training.epochs``,
            ``callbacks.early_stopping.params``,
            ``callbacks.model_checkpoint.params``, ``general.save_dir``,
            ``general.workspace``, ``general.project_name`` and
            ``private.comet_api``. ``LitM5NBeats`` receives the whole config
            and may read more keys.
    """
    set_seed(cfg.training.seed)
    model = LitM5NBeats(cfg=cfg)

    early_stopping = pl.callbacks.EarlyStopping(**cfg.callbacks.early_stopping.params)
    model_checkpoint = pl.callbacks.ModelCheckpoint(**cfg.callbacks.model_checkpoint.params)

    # The Comet experiment is named after the last component of the current
    # working directory (presumably the per-run output dir Hydra switches
    # into — confirm for your Hydra version). os.path.basename handles the
    # platform's separators; splitting on '\\' alone only works on Windows
    # and yields the whole absolute path on POSIX.
    experiment_name = os.path.basename(os.getcwd())

    tb_logger = TensorBoardLogger(cfg.general.save_dir)
    comet_logger = CometLogger(
        save_dir=cfg.general.save_dir,
        workspace=cfg.general.workspace,
        project_name=cfg.general.project_name,
        api_key=cfg.private.comet_api,
        experiment_name=experiment_name,
    )

    trainer = pl.Trainer(
        gpus=2,
        logger=[tb_logger, comet_logger],
        distributed_backend='dp',
        max_epochs=cfg.training.epochs,
        early_stop_callback=early_stopping,
        checkpoint_callback=model_checkpoint,
        # num_sanity_val_steps is the current name of the deprecated
        # nb_sanity_val_steps; 0 skips the pre-training validation sanity check.
        num_sanity_val_steps=0,
    )
    trainer.fit(model)
@hydra.main(config_path="conf/config.yaml")
def run_model(cfg: DictConfig) -> None:
    """Hydra entry point: print the loaded config, then launch training.

    Args:
        cfg: Config that Hydra builds from ``conf/config.yaml``.
    """
    config_dump = cfg.pretty()
    print(config_dump)
    run(cfg)
# Script entry point. run_model is wrapped by hydra.main, which is expected to
# supply cfg, so it is called here without arguments.
if __name__ == "__main__":
    run_model()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.