Last active
September 27, 2019 00:28
-
-
Save AnchorBlues/95dab2afa2f9ad875d89edac495a235a to your computer and use it in GitHub Desktop.
機械学習のモデル学習スクリプトでよく使うargparse
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_parser():
    """Build the argparse parser for a typical model-training script.

    The options collect commonly used training hyperparameters. Each group is
    adapted from the upstream script cited in the comment above it.

    Returns:
        argparse.ArgumentParser: a parser with every option registered; call
        ``parse_args()`` on it to read the command line.
    """
    # Local import keeps importing this module free of side effects.
    import argparse

    parser = argparse.ArgumentParser(description='PyTorch IMDB Example')

    # Adapted from
    # https://github.com/nn116003/self-attention-classification/blob/master/imdb_attn.py
    parser.add_argument('--h-dim', type=int, default=32, metavar='N',
                        help='hidden state dim (default: 32)')
    parser.add_argument('--emb_dim', type=int, default=100, metavar='N',
                        help='word embedding dim (default: 100)')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=3, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save_dir', type=str,
                        default="../models/weights/results1/",
                        help='where to save the training result')

    # Adapted from https://github.com/pytorch/examples/blob/master/mnist/main.py
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')

    # Adapted from https://docs.python.org/ja/3/library/argparse.html
    # `choices` makes argparse reject any value outside the listed set.
    parser.add_argument('--move', choices=['rock', 'paper', 'scissors'],
                        help='move to play (one of: rock, paper, scissors)')

    # Adapted from https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/master/train.py
    parser.add_argument("--n_cpu", type=int, default=8,
                        help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416,
                        help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1,
                        help="interval between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=1,
                        help="interval evaluations on validation set")

    # Adapted from https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/master/test.py
    parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights",
                        help="path to weights file")
    return parser


if __name__ == '__main__':
    # torch is needed only when run as a script; importing it here means
    # `import`-ing this module does not require torch to be installed.
    import torch

    args = build_parser().parse_args()
    # Use the GPU only when --no-cuda was not given and CUDA is available.
    device = torch.device("cuda" if not args.no_cuda and torch.cuda.is_available() else "cpu")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
モデルを学習するスクリプトであれば、以下のようにargumentをいくつかのグループ
にわけて整理すると非常に見やすくなります。