finetune imagenet

Fine-tunes a pretrained InceptionResNetV2 on an image-classification dataset with PyTorch: SGD with optional gradient accumulation, step learning-rate decay, TensorBoard logging, and periodic checkpointing of the best validation and training accuracy. The optimizer only updates parameters with requires_grad=True, which is the usual hook for freezing part of the network while fine-tuning.
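The gist itself never freezes any layers, so by default every parameter is trained. A minimal sketch of how a backbone freeze could be done before the optimizer is built, assuming the classifier attribute is named last_linear as in Cadene's pretrainedmodels InceptionResNetV2 (illustrative only, not part of the gist):

# Illustrative only: freeze the backbone and fine-tune just the classifier head.
# Assumes `model` is the InceptionResNetV2 instance created in main() below and
# that its classifier is named `last_linear` (as in Cadene's pretrainedmodels).
for param in model.parameters():
    param.requires_grad = False
for param in model.last_linear.parameters():
    param.requires_grad = True

The main training script follows.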
import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from tensorboard_logger import configure
from torch.utils import model_zoo
from torch.utils.data.sampler import RandomSampler

BATCH_SIZE = 100

parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--epochs', default=300, type=int,
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=BATCH_SIZE, type=int,
                    help='mini-batch size')
parser.add_argument('--accum', default=1, type=int,
                    help='number of batches over which to accumulate gradients before each optimizer step')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                    help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--save-steps-freq', default=1000, type=int,
                    help='frequency (in steps) to checkpoint and log to TensorBoard')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='inceptionresnetv2_2', type=str,
                    help='name of experiment')
# data_dir (the dataset root) is assumed to be defined elsewhere in the author's project.
out_dir = data_dir + 'out/'
def main():
    global args
    args = parser.parse_args()
    configure(out_dir + "logs/%s" % (args.name))

    best_prec1 = 0
    best_train_prec1 = 0
    samples = 0

    # Data loading code
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(160),
        # transforms.RandomCrop(160),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.CenterCrop(160),
        transforms.ToTensor(),
        normalize
    ])
    kwargs = {'num_workers': 3, 'pin_memory': True}

    # TrainDataset and ValDataset are assumed to be defined elsewhere in the author's project.
    train_dataset = TrainDataset(transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=args.batch_size,
        **kwargs)
    val_dataset = ValDataset(transform=transform_test)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=RandomSampler(val_dataset),
        batch_size=args.batch_size,
        **kwargs)

    # NUM_CLASSES and inceptionresnetv2 are likewise assumed to be defined/imported elsewhere
    # (inceptionresnetv2 is presumably Cadene's pretrainedmodels implementation).
    model = inceptionresnetv2(num_classes=NUM_CLASSES, pretrained='imagenet+background')
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), args.lr,
                                momentum=args.momentum,
                                # nesterov=True,
                                weight_decay=args.weight_decay)
    scheduler = None
scheduler = None | |
# optionally resume from a checkpoint | |
if args.resume=='': | |
args.resume = out_dir + args.name + '/checkpoint.pth.tar' | |
if not os.path.isfile(args.resume): | |
args.resume = None | |
if args.resume: | |
if os.path.isfile(args.resume): | |
print("=> loading checkpoint '{}'".format(args.resume)) | |
checkpoint: dict = torch.load(args.resume) | |
best_prec1 = checkpoint['best_prec1'] | |
best_train_prec1 = checkpoint.get('best_train_prec1', 0) | |
samples = checkpoint.get('samples', None) | |
if samples is None: | |
samples = 0 | |
model.load_state_dict(checkpoint['state_dict']) | |
print("=> loaded checkpoint '{}' samples={} best_prec1={} best_train_prec1={}" | |
.format(args.resume, samples, best_prec1, best_train_prec1)) | |
else: | |
print("=> no checkpoint found at '{}'".format(args.resume)) | |
raise ValueError | |
cudnn.benchmark = True | |
print(args) | |
for epoch in range(samples//len(train_dataset), args.epochs): | |
adjust_learning_rate(optimizer, epoch, args.lr) | |
# train for one epoch | |
best_prec1, best_train_prec1, samples = train(train_dataset, train_loader, model, criterion, optimizer, val_loader, | |
best_prec1, best_train_prec1, samples, | |
out_dir + args.name + '/', args, scheduler) | |
print('Best accuracy: ', best_prec1) | |
if __name__ == '__main__': | |
main() |
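The script above uses data_dir, NUM_CLASSES, TrainDataset, ValDataset, inceptionresnetv2, train, and adjust_learning_rate without defining them; they come from the rest of the author's project (the last two from the second file below). A minimal sketch of what the missing dataset and constant definitions might look like, assuming ImageFolder-style train/ and val/ directories; every name and path here is illustrative, not the author's actual code:

import torchvision.datasets as datasets

# Hypothetical values; the real project defines these elsewhere.
data_dir = '/path/to/data/'   # dataset root
NUM_CLASSES = 1001            # number of target classes

# inceptionresnetv2 is presumably Cadene's pretrained implementation:
# from pretrainedmodels import inceptionresnetv2

class TrainDataset(datasets.ImageFolder):
    """Hypothetical stand-in: an ImageFolder over the training split."""
    def __init__(self, transform=None):
        super(TrainDataset, self).__init__(data_dir + 'train/', transform=transform)

class ValDataset(datasets.ImageFolder):
    """Hypothetical stand-in: an ImageFolder over the validation split."""
    def __init__(self, transform=None):
        super(ValDataset, self).__init__(data_dir + 'val/', transform=transform)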
The second file in the gist contains the training and validation helpers used above:
import os
import shutil
import time

import torch
from tensorboard_logger import log_value

# NOTE: this code targets pre-0.4 PyTorch (Variable/volatile autograd API,
# `async=True` for .cuda(), and `loss.data[0]` for scalar extraction).
def train(train_dataset, train_loader, model, criterion, optimizer, val_loader, best_prec1, best_train_prec1, samples, checkpoint_directory, args, scheduler):
    """Train for one epoch on the training set"""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        if scheduler is not None:
            scheduler.batch_step()
        target = target.cuda(async=True)
        input = input.cuda()
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        samples += input.size(0)

        # compute gradient and do SGD step
        # (gradients accumulate over args.accum batches before each optimizer step)
        loss.backward()
        if i % args.accum == 0:
            optimizer.step()
            optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: {0:.4f}\t'
                  'Step: {1}/{2}\t'
                  'Samples: [{samples}]\t'
                  'LR: {lr}\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Samples/s {samples_per_sec:.0f}\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      samples / len(train_dataset), i, len(train_loader), samples=samples, batch_time=batch_time,
                      samples_per_sec=input.size(0) / batch_time.avg,
                      lr=get_learning_rate(optimizer)[0],  # *iter_accum ???
                      loss=losses, top1=top1))

        if i % args.save_steps_freq == 0:
            if i > 0:
                # evaluate on validation set
                prec1 = validate(val_loader, model, criterion, samples, args)
                # switch back to train mode after validation
                model.train()

                # remember best prec@1 and save checkpoint
                print('Checkpoint')
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                is_best_train = top1.avg > best_train_prec1
                best_train_prec1 = max(top1.avg, best_train_prec1)
                save_checkpoint({
                    'samples': samples,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'best_train_prec1': best_train_prec1,
                    'train_prec1': top1.avg,
                }, is_best, is_best_train,
                    directory=checkpoint_directory)

            # log to TensorBoard
            log_value('train_loss', losses.avg, samples)
            log_value('train_acc', top1.avg, samples)
            log_value('learning_rate', get_learning_rate(optimizer)[0], samples)
            log_value('batch_size', input.size(0), samples)
            log_value('effective_batch_size', input.size(0) * args.accum, samples)
            log_value('accum', args.accum, samples)
            batch_time.reset()
            losses.reset()
            top1.reset()

    return best_prec1, best_train_prec1, samples
def validate(val_loader, model, criterion, samples, args):
    """Perform validation on the validation set"""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        # print("input={}", input.size())
        target = target.cuda(async=True)
        input = input.cuda()
        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        # print("validate vars input={} target={} output={}".format(input_var.size(), target_var.size(), output.size()))
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
    # log to TensorBoard
    log_value('val_loss', losses.avg, samples)
    log_value('val_acc', top1.avg, samples)
    return top1.avg


def get_learning_rate(optimizer):
    if optimizer is None:
        return [0.0]
    lr = []
    for param_group in optimizer.param_groups:
        lr += [param_group['lr']]
    return lr


def save_checkpoint(state, is_best, is_best_train, directory, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk"""
    if not os.path.exists(directory):
        os.makedirs(directory)
    filename = directory + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, directory + 'model_best.pth.tar')
    if is_best_train:
        shutil.copyfile(filename, directory + 'model_best_train.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch, lr):
    """Sets the learning rate to the initial LR decayed by 10 after 150 and 225 epochs"""
    lr = lr * (0.1 ** (epoch // 150)) * (0.1 ** (epoch // 225))
    # log to TensorBoard
    log_value('learning_rate', lr, epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
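As a quick sanity check (not part of the gist), accuracy and AverageMeter behave like this on a toy batch, using the same pre-0.4 tensor API as the rest of the code:

import torch

# Toy logits for 4 samples over 3 classes; the true labels are 2, 0, 1, 1.
output = torch.FloatTensor([[0.1, 0.2, 0.7],
                            [0.9, 0.05, 0.05],
                            [0.3, 0.4, 0.3],
                            [0.6, 0.3, 0.1]])
target = torch.LongTensor([2, 0, 1, 1])

prec1 = accuracy(output, target, topk=(1,))[0]  # top-1: 3 of 4 correct -> 75.0

meter = AverageMeter()
meter.update(prec1[0], n=target.size(0))  # running average weighted by batch size
print(prec1[0], meter.avg)                # both print 75.0 under the pre-0.4 API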