Last active
November 14, 2018 22:30
-
-
Save jeasinema/01a2ebed8194c1f716223a6a3f06f58e to your computer and use it in GitHub Desktop.
A naive PyTorch template
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import StepLR  # fixed: was `from optim.lr_scheduler`, which is not importable
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate  # added: used by my_collate() but never imported

from MobileNetV2 import MobileNetV2
from foodnet_loader import FoodnetDataset
# ---- Command line interface ----------------------------------------------
parser = argparse.ArgumentParser(description='foodnet')
parser.add_argument('--tag', type=str, default='default')
parser.add_argument('--epoch', type=int, default=200)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--mode', choices=['train', 'test'], required=True)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--load-model', type=str, default='')
parser.add_argument('--load-epoch', type=int, default=-1)
parser.add_argument('--model-path', type=str, default='./assets/learned_models',
                    help='pre-trained model path')
parser.add_argument('--data-path', type=str, default='./data', help='data path')
parser.add_argument('--log-interval', type=int, default=10)
parser.add_argument('--save-interval', type=int, default=10)
args = parser.parse_args()

# Bug fix: the original tested `torch.cuda.is_available` without calling it.
# A function object is always truthy, so `--cuda` was never disabled on
# CPU-only hosts and the script crashed later. Call it.
args.cuda = args.cuda and torch.cuda.is_available()

# Seed every RNG source from the one user-supplied seed for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    # Bug fix: was hard-coded to 1, ignoring --seed.
    torch.cuda.manual_seed(args.seed)

# TensorBoard event files go under ./assets/log/<tag>.
logger = SummaryWriter(os.path.join('./assets/log/', args.tag))
def worker_init_fn(pid):
    """Seed NumPy inside each DataLoader worker process.

    Without this, every worker inherits the parent's NumPy RNG state and
    produces identical "random" augmentations. Derive the seed from torch's
    per-worker initial seed, clamped into NumPy's valid 32-bit range.
    """
    seed = torch.initial_seed() % (2 ** 31 - 1)
    np.random.seed(seed)
def my_collate(batch):
    """Collate a batch, silently dropping samples that are ``None``.

    Presumably the dataset's ``__getitem__`` returns ``None`` for samples it
    could not load (TODO confirm against FoodnetDataset); filtering here keeps
    one bad sample from killing the whole batch.
    """
    # Bug fix: `default_collate` was referenced but never imported anywhere
    # in the original file, so every batch raised NameError. Import locally
    # so this function is self-contained.
    from torch.utils.data.dataloader import default_collate
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)
# Training split: shuffled, 32 loader processes, memory pinned for faster
# host-to-GPU copies. Each worker reseeds NumPy via worker_init_fn, and
# my_collate drops samples the dataset failed to load.
_train_set = FoodnetDataset(
    img_size=224,
    split_file=os.path.join(args.data_path, 'food-101/meta/train'),
    path=args.data_path,
)
train_loader = torch.utils.data.DataLoader(
    _train_set,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
    worker_init_fn=worker_init_fn,
    collate_fn=my_collate,
)
# Evaluation split, built identically to the training loader.
# NOTE(review): shuffle=True on the test split does not change the reported
# accuracy/loss (both are averaged over the full set) but is unusual —
# kept as-is to preserve behavior.
_test_set = FoodnetDataset(
    img_size=224,
    split_file=os.path.join(args.data_path, 'food-101/meta/test'),
    path=args.data_path,
)
test_loader = torch.utils.data.DataLoader(
    _test_set,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
    worker_init_fn=worker_init_fn,
    collate_fn=my_collate,
)
# ---- Model construction, checkpoint resume, device placement --------------
# is_resume is 1 only when BOTH a checkpoint path and a resume epoch were
# given; main() uses it to offset the starting epoch of the training loop.
is_resume = 0
if args.load_model and args.load_epoch != -1:
    is_resume = 1

# 101 output classes (Food-101), 224x224 input resolution.
model = MobileNetV2(n_class=101, input_size=224)
if is_resume or args.mode == 'test':
    # Load weights onto CPU first; the model is moved to its device below.
    # NOTE(review): in test mode with the default --load-model '' this
    # torch.load call will fail — a checkpoint path is effectively required.
    model.load_state_dict(torch.load(args.load_model, map_location='cpu'))
    print('load model {}'.format(args.load_model))

if args.cuda:
    if args.gpu != -1:
        # Single-GPU path: pin this process to the requested device.
        torch.cuda.set_device(args.gpu)
        model = model.cuda()
    else:
        # Multi-GPU path: replicate across four devices with DataParallel.
        # Statement order matters here, as the original author notes:
        device_id = [0,1,2,3]
        torch.cuda.set_device(device_id[0]) # when 0 is not in device_id, must executed before dataparallel
        model = nn.DataParallel(model.cuda(), device_ids=device_id)

optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Decay the LR by 10x every 30 epochs; train() invokes scheduler.step()
# once per epoch.
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
def train(model, loader, epoch):
    """Run one training epoch over `loader`; return the epoch's accuracy.

    Uses the module-level optimizer, scheduler, logger, and args. Loss is
    NLL because (per the original author's note) the model already emits
    log_softmax outputs.
    """
    scheduler.step()  # epoch-level LR decay (pre-1.1 PyTorch call order)
    model.train()
    torch.set_grad_enabled(True)

    n_correct = 0
    n_seen = 0
    for step, (data, target) in enumerate(loader):
        n_seen += data.shape[0]
        data = torch.FloatTensor(data)
        target = torch.LongTensor(target).squeeze()
        if args.cuda:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        pred = output.data.max(1, keepdim=True)[1]
        n_correct += pred.eq(target.view_as(pred)).long().cpu().sum()

        if step % args.log_interval == 0:
            seen_approx = step * args.batch_size
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t{}'.format(
                epoch, seen_approx, len(loader.dataset),
                100. * seen_approx / len(loader.dataset), loss.item(), args.tag))
            logger.add_scalar('train_loss', loss.item(),
                              step + epoch * len(loader))

    return float(n_correct) / float(n_seen)
def test(model, loader):
    """Evaluate `model` over `loader`.

    Returns:
        (accuracy, mean_nll_loss) averaged over every sample seen.
    Reads the module-level `args` for device placement. Leaves grad globally
    disabled on return (train() re-enables it).
    """
    model.eval()
    torch.set_grad_enabled(False)
    test_loss = 0
    correct = 0
    dataset_size = 0
    for data, target in loader:
        dataset_size += data.shape[0]
        data, target = torch.FloatTensor(data), torch.LongTensor(target).squeeze()
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        output = model(data)  # N*C log-probabilities
        # Fix: `size_average=False` has been removed from current PyTorch;
        # reduction='sum' is the documented equivalent (sum over the batch,
        # normalized by dataset_size below).
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).long().cpu().sum()
    test_loss /= float(dataset_size)
    acc = float(correct) / float(dataset_size)
    return acc, test_loss
def main():
    """Entry point: train for args.epoch epochs with periodic checkpoints,
    or run a single evaluation pass, depending on args.mode."""
    if args.mode == 'train':
        # When resuming, start counting from the checkpoint's epoch
        # (is_resume is 0 for a fresh run, so the range starts at 0).
        for epoch in range(is_resume * args.load_epoch, args.epoch):
            acc_train = train(model, train_loader, epoch)
            print('Train done, acc={}'.format(acc_train))
            acc, loss = test(model, test_loader)
            print('Test done, acc={}, loss={}'.format(acc, loss))
            logger.add_scalar('train_acc', acc_train, epoch)
            logger.add_scalar('test_acc', acc, epoch)
            logger.add_scalar('test_loss', loss, epoch)
            if epoch % args.save_interval == 0:
                path = os.path.join(args.model_path, args.tag + '_{}.model'.format(epoch))
                # Bug fix: the original `model.cpu().state_dict()` moved the
                # module to CPU *in place* and never moved it back, so on GPU
                # runs the first batch after every checkpoint crashed with a
                # device mismatch. Copy the tensors to CPU instead of moving
                # the live model.
                cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save(cpu_state, path)
                print('Save model @ {}'.format(path))
    else:
        acc, loss = test(model, test_loader)
        print('Test done, acc={}, loss={}'.format(acc, loss))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment