Fine-tuning pre-trained models with PyTorch
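The script below adapts the canonical PyTorch ImageNet training example for transfer learning: it wraps a pre-trained torchvision model, freezes its convolutional features, and trains a freshly initialized classifier head on a new dataset.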
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--finetune', dest='finetune', action='store_true',
                    help='fine-tune pre-trained model')

best_prec1 = 0
class FineTuneModel(nn.Module):
    def __init__(self, original_model, arch, num_classes):
        super(FineTuneModel, self).__init__()

        if arch.startswith('alexnet'):
            self.features = original_model.features
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )
            self.modelName = 'alexnet'
        elif arch.startswith('resnet'):
            # Everything except the last linear layer
            self.features = nn.Sequential(*list(original_model.children())[:-1])
            self.classifier = nn.Sequential(
                # 512 matches resnet18/34; resnet50 and deeper output 2048 features
                nn.Linear(512, num_classes)
            )
            self.modelName = 'resnet'
        elif arch.startswith('vgg16'):
            self.features = original_model.features
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(25088, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )
            self.modelName = 'vgg16'
        else:
            # raising a plain string is a TypeError; raise a proper exception
            raise NotImplementedError("Finetuning not supported on this architecture yet")

        # Freeze those weights
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        f = self.features(x)
        if self.modelName == 'alexnet':
            f = f.view(f.size(0), 256 * 6 * 6)
        elif self.modelName == 'vgg16':
            f = f.view(f.size(0), -1)
        elif self.modelName == 'resnet':
            f = f.view(f.size(0), -1)
        y = self.classifier(f)
        return y
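# Illustrative only (not part of the original flow): building a fine-tuning
# model for a hypothetical 10-class dataset would look like
#   original = models.resnet18(pretrained=True)
#   model = FineTuneModel(original, 'resnet18', num_classes=10)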
def main():
    global args, best_prec1
    args = parser.parse_args()

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    # Get number of classes from the train directory (one subfolder per class)
    num_classes = len(os.listdir(traindir))
    print("num_classes = '{}'".format(num_classes))

    # create model
    if args.finetune:
        print("=> using pre-trained model '{}'".format(args.arch))
        original_model = models.__dict__[args.arch](pretrained=True)
        model = FineTuneModel(original_model, args.arch, num_classes)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomSizedCrop(224),  # renamed RandomResizedCrop in later torchvision releases
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Scale(256),  # renamed Resize in later torchvision releases
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),  # Only finetunable params
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
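    # Note: with --finetune the frozen feature weights have
    # requires_grad == False, so the filter above hands the optimizer
    # only the freshly initialized classifier parameters.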
    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # `async` was renamed `non_blocking` in PyTorch 0.4 (it became a
        # reserved word in Python 3.7)
        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(async=True)
        # volatile=True (no autograd history) was the pre-0.4 equivalent
        # of torch.no_grad()
        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
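# Illustrative example: m = AverageMeter(); m.update(2, n=4); m.update(6)
# leaves m.val == 6, m.sum == 14, m.count == 5, m.avg == 2.8.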
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    # mutate the live param_groups; optimizer.state_dict() returns a copy,
    # so writing into it would not update the optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
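# With the default --lr 0.1 this schedule yields lr = 0.1 for epochs 0-29,
# 0.01 for epochs 30-59, and 0.001 for epochs 60-89.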
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
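# Worked example (illustrative): for a batch of two samples with
#   output = torch.Tensor([[0.1, 0.9], [0.8, 0.2]]) and
#   target = torch.LongTensor([1, 0]),
# both top-1 predictions are correct, so accuracy(output, target)
# returns [100.0].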
if __name__ == '__main__':
    main()
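The positional DIR argument is expected to point at a directory containing train/ and val/ subfolders in the torchvision ImageFolder layout (one subfolder per class). A hypothetical invocation, assuming the script is saved as finetune.py and the dataset lives under /data/mydataset: python finetune.py /data/mydataset --arch resnet18 --finetune --epochs 30 -b 64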