Created November 15, 2019 16:24
-
-
Save williamFalcon/0cf89350e1efb3e269e43c52c06f1417 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os

import pytorch_lightning as pl
def main(hparams):
    """Build the fast-style-transfer model from ``hparams`` and train it.

    Args:
        hparams: the parsed command-line namespace; it is passed unchanged to
            ``FastStyleTransfer``.
    """
    # init model
    model = FastStyleTransfer(hparams)
    # NOTE(review): none of the CLI options (--cuda, --seed, --save-model-dir,
    # ...) are forwarded to the Trainer here. Confirm whether FastStyleTransfer
    # consumes them itself or whether they should be wired into the Trainer.
    trainer = pl.Trainer()
    # Trainer.fit needs the LightningModule to train; calling it with no
    # arguments never trains the model built above.
    trainer.fit(model)
if __name__ == '__main__':
    # Script-level CLI options; model-specific options are appended below by
    # FastStyleTransfer.add_model_specific_args.
    main_arg_parser = argparse.ArgumentParser(description="parser for fast-neural-style")
    main_arg_parser.add_argument("--save-model-dir", type=str, required=True,
                                 help="path to folder where trained model will be saved.")
    main_arg_parser.add_argument("--checkpoint-model-dir", type=str, default=None,
                                 help="path to folder where checkpoints of trained models will be saved")
    main_arg_parser.add_argument("--cuda", type=int, required=True,
                                 help="set it to 1 for running on GPU, 0 for CPU")
    main_arg_parser.add_argument("--seed", type=int, default=42,
                                 help="random seed for training")
    main_arg_parser.add_argument("--log-interval", type=int, default=500,
                                 help="number of images after which the training loss is logged, default is 500")
    # Hyphenated spelling matches the other flags; the original underscore
    # spelling is kept as an alias so existing command lines keep working.
    # Both spellings store into ``args.export_onnx``.
    main_arg_parser.add_argument("--export-onnx", "--export_onnx", dest="export_onnx", type=str,
                                 help="export ONNX model to a given file")
    main_arg_parser.add_argument("--checkpoint-interval", type=int, default=2000,
                                 help="number of batches after which a checkpoint of the trained model will be created")
    # add model specific args (the current working directory is passed as the
    # root directory argument expected by add_model_specific_args)
    parser = FastStyleTransfer.add_model_specific_args(main_arg_parser, os.getcwd())
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.