Skip to content

Instantly share code, notes, and snippets.

@AnchorBlues
Last active September 27, 2019 00:28
Show Gist options
  • Save AnchorBlues/95dab2afa2f9ad875d89edac495a235a to your computer and use it in GitHub Desktop.
Save AnchorBlues/95dab2afa2f9ad875d89edac495a235a to your computer and use it in GitHub Desktop.
機械学習のモデル学習スクリプトでよく使うargparse
def build_imdb_parser():
    """Build an argument parser holding commonly used training-script options.

    The options are collected from several public example scripts (linked in
    the comments below) and serve as a reference set of typical flags for a
    PyTorch model-training script.

    Returns:
        argparse.ArgumentParser: parser with every option registered; call
        ``parse_args()`` on it to read the command line.
    """
    import argparse

    # Adapted from
    # https://github.com/nn116003/self-attention-classification/blob/master/imdb_attn.py
    parser = argparse.ArgumentParser(description='PyTorch IMDB Example')
    parser.add_argument('--h-dim', type=int, default=32, metavar='N',
                        help='hidden state dim (default: 32)')
    parser.add_argument('--emb_dim', type=int, default=100, metavar='N',
                        help='word embedding dim (default: 100)')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=3, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save_dir', type=str,
                        default="../models/weights/results1/",
                        help='where to save the training result')

    # https://github.com/pytorch/examples/blob/master/mnist/main.py
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')

    # https://docs.python.org/ja/3/library/argparse.html
    # Example of restricting an option to a fixed set of values via `choices`.
    parser.add_argument('--move', choices=['rock', 'paper', 'scissors'])

    # https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/master/train.py
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")

    # https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/master/test.py
    parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")
    return parser


if __name__ == '__main__':
    # torch is needed for the device selection below; importing it here keeps
    # build_imdb_parser() usable without torch installed.
    import torch

    args = build_imdb_parser().parse_args()
    # Use the GPU unless --no-cuda was passed or CUDA is unavailable.
    device = torch.device("cuda" if not args.no_cuda and torch.cuda.is_available() else "cpu")
@AnchorBlues
Copy link
Author

モデルを学習するスクリプトであれば、以下のようにargumentを

  • basic settings
  • data
  • data loader
  • model
  • training

にわけて整理すると非常に見やすくなります。

def build_training_parser():
    """Build the command-line parser for a model-training script.

    Options are grouped into basic settings, data, data loader, model and
    training sections so the flag list stays easy to scan.

    Returns:
        argparse.ArgumentParser: parser with every option registered; call
        ``parse_args()`` on it to read the command line.
    """
    import argparse

    parser = argparse.ArgumentParser(description='')

    # basic settings
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')
    parser.add_argument('--save_dir', type=str,
                        default="../../models/weights/sngan",
                        help='where to save the training result')

    # data
    parser.add_argument("--image_dir", type=str,
                        default="../../data/raw/coco/train2014/",
                        help="directory including coco image files")
    parser.add_argument("--encoded_vec_train_path", type=str,
                        default="../../data/external/cnn_dcnn_results/train_encoded_vectors.npy",
                        help="numpy file path of encoded vectors(train)")
    parser.add_argument("--encoded_vec_valid_path", type=str,
                        default="../../data/external/cnn_dcnn_results/valid_encoded_vectors.npy",
                        help="numpy file path of encoded vectors(valid)")
    parser.add_argument("--df_train_path", type=str,
                        default="../../data/interim/coco/train.csv",
                        help="path to data frame file of train dataset")
    parser.add_argument("--df_valid_path", type=str,
                        default="../../data/interim/coco/valid.csv",
                        help="path to data frame file of valid dataset")

    # data loader
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch size for training')

    # model
    parser.add_argument("--model_config_path", type=str,
                        default="../../configs/models/settings.cfg",
                        help="path to the model config file")
    parser.add_argument('--d_weights_path', type=str, default="",
                        help='path to weights file of Discriminator. if not set, weights will be initialized.')
    parser.add_argument('--g_weights_path', type=str, default="",
                        help='path to weights file of Generator. if not set, weights will be initialized.')

    # training
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--n_epochs', type=int, default=20,
                        help='number of epochs for train')
    # The double-dash form matches every other option; the single-dash
    # spelling is kept as an alias so existing command lines still work.
    # dest is taken from the first long option, so it stays `save_interval`.
    parser.add_argument('--save_interval', '-save_interval', type=int, default=5,
                        help='how many epochs to wait before saving')
    parser.add_argument("--reg_criterion", type=str, default="L2",
                        choices=['L1', 'L2'], help="loss function of regression loss")
    # Float default so args.reg_lambda has the same type whether or not the
    # flag is given (argparse does not convert non-string defaults).
    parser.add_argument("--reg_lambda", type=float, default=10.0,
                        help="coefficient of regression loss")
    return parser


if __name__ == "__main__":
    args = build_training_parser().parse_args()
    print(args)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment