Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Created November 15, 2019 16:24
Show Gist options
Save williamFalcon/0cf89350e1efb3e269e43c52c06f1417 to your computer and use it in GitHub Desktop.
import argparse
import os

import pytorch_lightning as pl
def main(hparams):
    """Build the style-transfer model from parsed CLI arguments and train it.

    Args:
        hparams: the ``argparse.Namespace`` produced by the script's argument
            parser; it is passed unchanged to ``FastStyleTransfer``.
    """
    # NOTE(review): FastStyleTransfer is neither defined nor imported in this
    # snippet; it is expected to be provided elsewhere in the project.
    model = FastStyleTransfer(hparams)
    trainer = pl.Trainer()
    # Trainer.fit needs the model it should train; calling fit() with no
    # arguments leaves the Trainer with nothing to optimize.
    trainer.fit(model)
def build_arg_parser():
    """Return the parser for the script-level (non-model) CLI options.

    Model-specific options are appended separately by
    ``FastStyleTransfer.add_model_specific_args`` in the ``__main__`` guard.
    """
    parser = argparse.ArgumentParser(description="parser for fast-neural-style")
    parser.add_argument("--save-model-dir", type=str, required=True,
                        help="path to folder where trained model will be saved.")
    parser.add_argument("--checkpoint-model-dir", type=str, default=None,
                        help="path to folder where checkpoints of trained models will be saved")
    parser.add_argument("--cuda", type=int, required=True,
                        help="set it to 1 for running on GPU, 0 for CPU")
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed for training")
    parser.add_argument("--log-interval", type=int, default=500,
                        help="number of images after which the training loss is logged, default is 500")
    # Accept the hyphenated spelling used by every other flag while keeping
    # the original underscore form so existing command lines still work.
    # dest is pinned so the parsed attribute stays ``args.export_onnx``.
    parser.add_argument("--export-onnx", "--export_onnx", dest="export_onnx", type=str,
                        help="export ONNX model to a given file")
    parser.add_argument("--checkpoint-interval", type=int, default=2000,
                        help="number of batches after which a checkpoint of the trained model will be created")
    return parser


if __name__ == '__main__':
    main_arg_parser = build_arg_parser()
    # add model specific args
    parser = FastStyleTransfer.add_model_specific_args(main_arg_parser, os.getcwd())
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment