# multi-class cross-entropy loss with center-exclusive
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import lr_scheduler
torch.backends.cudnn.benchmark = True
import os, sys, random, datetime, time
from os.path import isdir, isfile, join, dirname, abspath
import argparse
import numpy as np
from PIL import Image
from scipy.io import savemat
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from utils import accuracy, test_lfw, AverageMeter, save_checkpoint, str2bool, Logger

THIS_DIR = abspath(dirname(__file__))
TMP_DIR = join(THIS_DIR, 'tmp')
if not isdir(TMP_DIR):
    os.makedirs(TMP_DIR)
parser = argparse.ArgumentParser(description='PyTorch implementation of cross-entropy loss with center-exclusive.')
parser.add_argument('--bs', type=int, help='batch size', default=600)
# optimizer parameters
parser.add_argument('--lr', type=float, help='base learning rate', default=0.1)
parser.add_argument('--momentum', type=float, help='momentum', default=0.9)
parser.add_argument('--stepsize', type=int, help='step size (epoch)', default=18)
parser.add_argument('--gamma', type=float, help='gamma', default=0.1)
parser.add_argument('--wd', type=float, help='weight decay', default=5e-4)
parser.add_argument('--maxepoch', type=int, help='maximal training epoch', default=30)
# model parameters
parser.add_argument('--exclusive_weight', type=float, help='center exclusive loss weight', default=6)
parser.add_argument('--radius', type=float, help='radius', default=15)
parser.add_argument('--l2filter', type=str, help='filter samples based on l2', default="True")
parser.add_argument('--warmup', type=int, help='warmup epoch', default=0)
# general parameters
parser.add_argument('--print_freq', type=int, help='print frequency', default=50)
parser.add_argument('--train', type=str, help='set to false to test lfw acc only', default="true")
parser.add_argument('--cuda', type=int, help='cuda', default=1)
parser.add_argument('--debug', type=str, help='debug mode', default='false')
parser.add_argument('--checkpoint', type=str, help='checkpoint prefix', default="center_exclusive")
parser.add_argument('--resume', type=str, help='checkpoint path', default=None)
parser.add_argument('--parallel', action='store_true')
# datasets
parser.add_argument('--casia', type=str, help='root folder of CASIA-WebFace dataset', default="data/CASIA-WebFace-112X96")
parser.add_argument('--num_class', type=int, help='num classes', default=10572)
parser.add_argument('--lfw', type=str, help='LFW dataset root folder', default="data/lfw-112X96")
parser.add_argument('--lfwlist', type=str, help='lfw image list', default='data/LFW_imagelist.txt')
args = parser.parse_args()
assert isfile(args.lfwlist) and isdir(args.lfw)
assert args.exclusive_weight > 0
assert args.cuda == 1
args.train = str2bool(args.train)
args.l2filter = str2bool(args.l2filter)
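# NOTE (assumption): str2bool comes from the local utils module, which is not part
# of this gist. A minimal implementation consistent with how it is used here might be:
#   def str2bool(s):
#       return str(s).lower() in ("true", "t", "yes", "1")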
if args.l2filter:
    args.checkpoint = join(TMP_DIR, args.checkpoint) + "-filter-exclusive_weight%.2f-radius%.1f-warmup%d-" % \
                      (args.exclusive_weight, args.radius, args.warmup) + \
                      datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
else:
    args.checkpoint = join(TMP_DIR, args.checkpoint) + "-exclusive_weight%.2f-radius%.1f-warmup%d-" % \
                      (args.exclusive_weight, args.radius, args.warmup) + \
                      datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
print("Checkpoint directory: %s" % args.checkpoint)
if not isdir(args.checkpoint):
    os.makedirs(args.checkpoint)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
if args.train:
    print("Pre-loading training data...")
    train_dataset = datasets.ImageFolder(
        args.casia,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.bs, shuffle=True,
        num_workers=8, pin_memory=True
    )
    print("Done!")
# transforms for LFW testing data
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])
# model and optimizer
from models import CenterExclusive
print("Loading model...")
model = CenterExclusive(num_class=args.num_class, norm_data=True, radius=args.radius)
print("Done!")
# optimizer related
# per-sample losses are needed for L2-based re-weighting, so disable reduction then
# (the original `reduce=` keyword is deprecated in favor of `reduction=`)
criterion = nn.CrossEntropyLoss(reduction='none' if args.l2filter else 'mean')
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.momentum)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
# scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[20, 30, 35], gamma=0.5)
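# With the defaults above (lr=0.1, stepsize=18, gamma=0.1, maxepoch=30), StepLR
# yields lr=0.1 for epochs 0-17 and lr=0.01 for epochs 18-29.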
if args.cuda:
    print("Transporting model to GPU(s)...")
    model.cuda()
    print("Done!")
if args.parallel:
    model = nn.DataParallel(model)
def train_epoch(train_loader, model, optimizer, epoch):
    # recording
    loss_cls = AverageMeter()
    loss_exc = AverageMeter()
    top1 = AverageMeter()
    batch_time = AverageMeter()
    train_record = np.zeros((len(train_loader), 4), np.float32)  # loss, exc_loss, top1-acc, lr
    # switch to train mode
    model.train()
    for batch_idx, (data, label) in enumerate(train_loader):
        it = epoch * len(train_loader) + batch_idx
        # exclusive loss weight (ramped up linearly during warmup epochs)
        if epoch < args.warmup:
            # exclusive_weight = float(it) / (args.warmup * len(train_loader)) * args.exclusive_weight
            exclusive_weight = float(epoch) / args.warmup * args.exclusive_weight
            # exclusive_weight = 0
        else:
            exclusive_weight = args.exclusive_weight
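        # Example (hypothetical --warmup 5 with the default --exclusive_weight 6):
        # epochs 0..4 use weights 0.0, 1.2, 2.4, 3.6, 4.8; epoch 5 onward uses 6.0.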
        start_time = time.time()
        if args.cuda:
            data = data.cuda()
            label = label.cuda(non_blocking=True)
        prob, feature, center_exclusive_loss = model(data)
        if args.parallel:
            center_exclusive_loss = torch.mean(center_exclusive_loss)
        ##########################################
        if args.l2filter:
            bs = feature.size(0)
            feature_l2 = torch.norm(feature, p=2, dim=1).detach()
            feature_l2 = feature_l2.cpu().numpy()
            assert feature_l2.min() > 0
            if False:  # disabled alternative: drop bad examples and emphasize hard ones
                bad, hard = int(bs / 20), int(bs / 2)
                bad_examples = (feature_l2 <= np.sort(feature_l2)[bad])  # bad examples will be eliminated
                # hard examples will be emphasized
                hard_examples = np.logical_and(feature_l2 > np.sort(feature_l2)[bad], feature_l2 < np.sort(feature_l2)[hard])
                normal_examples = np.logical_not(np.logical_or(bad_examples, hard_examples))
                weight = feature_l2.copy()
                weight[normal_examples] = 1
                weight[bad_examples] = 0
                weight[hard_examples] /= weight[hard_examples].max()
                weight[hard_examples] = 1 / weight[hard_examples]
            else:
                num_decay = int(feature_l2.size / 10)
                decay_examples = feature_l2 < np.sort(feature_l2)[num_decay]
                normal_examples = np.logical_not(decay_examples)
                weight = feature_l2.copy()
                weight[normal_examples] = 1
                weight[decay_examples] -= weight[decay_examples].min()
                weight[decay_examples] /= weight[decay_examples].max()
                # weight[decay_examples] = 1 / weight[decay_examples]
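                # Illustration: with bs=600, the 60 samples (10%) with the smallest
                # feature L2-norms get weights rescaled to [0, 1] (smallest norm -> 0),
                # while the other 540 keep weight 1, so presumed low-quality
                # (low-norm) samples contribute less to the cross-entropy loss.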
            loss = criterion(prob, label)
            loss = torch.mul(loss, torch.from_numpy(weight).cuda()).mean()
        else:
            loss = criterion(prob, label)
        ##########################################
        # measure accuracy and record loss
        prec1, prec5 = accuracy(prob, label, topk=(1, 5))
        loss_cls.update(loss.item(), data.size(0))
        loss_exc.update(center_exclusive_loss.item(), data.size(0))
        top1.update(prec1[0], data.size(0))
        # collect losses
        loss = loss + exclusive_weight * center_exclusive_loss
        # clear cached gradient
        optimizer.zero_grad()
        # backward gradient
        loss.backward()
        # update parameters
        optimizer.step()
        batch_time.update(time.time() - start_time)
        if batch_idx % args.print_freq == 0:
            # use the meters' instantaneous values so loss_cls reports the
            # classification loss alone, not the combined objective
            print("Epoch %d/%d Batch %d/%d, (sec/batch: %.2fsec): loss_cls=%.3f (* 1), loss-exc=%.5f (* %.4f), acc1=%.3f, lr=%.3f" %
                  (epoch, args.maxepoch, batch_idx, len(train_loader), batch_time.val, loss_cls.val,
                   loss_exc.val, exclusive_weight, top1.val, scheduler.get_lr()[0]))
            if args.l2filter:
                plt.scatter(feature_l2, weight)
                plt.title("%dhard-%dbad" % (np.count_nonzero(np.logical_and(weight != 1, weight != 0)), np.count_nonzero(weight == 0)))
                plt.savefig(join(args.checkpoint, "Iter%d-feature-l2-vs-weight.jpg" % it))
                plt.close()
        train_record[batch_idx, :] = np.array([loss_cls.val, loss_exc.val, top1.val / float(100), scheduler.get_lr()[0]])
    return train_record
def main():
    lfw_acc_history = np.zeros((args.maxepoch, ), np.float32)
    # Logging to text file
    log = Logger(join(args.checkpoint, 'log.txt'))
    sys.stdout = log
    for epoch in range(args.maxepoch):
        scheduler.step()  # adjust learning rate (old-style API: stepped at the start of each epoch)
        if args.train:
            if epoch == 0:
                train_record = train_epoch(train_loader, model, optimizer, epoch)
            else:
                train_record = np.vstack((train_record, train_epoch(train_loader, model, optimizer, epoch)))
        # prepare data for testing
        with open(args.lfwlist, 'r') as f:
            imglist = f.readlines()
        imglist = [join(args.lfw, i.rstrip()) for i in imglist]
        lfw_acc_history[epoch] = test_lfw(model, imglist, test_transform, join(args.checkpoint, 'epoch%d' % epoch))
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=join(args.checkpoint, "epoch%d-lfw%f.pth" % (epoch, lfw_acc_history[epoch])))
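        # NOTE (assumption): save_checkpoint comes from utils (not shown); a minimal
        # version would simply be torch.save(state, filename).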
print("Epoch %d best LFW accuracy is %.5f." % (epoch, lfw_acc_history.max())) | |
# instantly flush log to text file | |
log.flush() | |
    # save logging figure
    if args.train:
        savemat(join(args.checkpoint, 'record(max-acc=%.5f).mat' % lfw_acc_history.max()),
                dict({"train_record": train_record,
                      "lfw_acc_history": lfw_acc_history}))
        fig, axes = plt.subplots(1, 5, figsize=(25, 5))
        for ax in axes:
            ax.grid(True)
        axes[0].plot(train_record[:, 0], 'r')  # loss cls
        axes[0].set_title("CELoss")
        axes[1].plot(train_record[:, 1], 'r')  # loss exclusive
        axes[1].set_title("ExLoss")
        axes[2].plot(train_record[:, 2], 'r')  # top1 acc
        axes[2].set_title("Trn-Acc")
        axes[3].plot(train_record[:, 3], 'r')  # learning rate
        axes[3].set_title("LR")
        axes[4].plot(lfw_acc_history.argmax(), lfw_acc_history.max(), 'r*', markersize=12)
        axes[4].plot(lfw_acc_history, 'r')
        axes[4].set_title("LFW-Acc")
        plt.suptitle("radius=%.1f, exclusive-loss $\\times$ %.1f max LFW-Acc=%.3f" % (args.radius, args.exclusive_weight,
                                                                                     lfw_acc_history.max()))
    else:
        savemat(join(args.checkpoint, 'record(max-acc=%.5f).mat' % lfw_acc_history.max()),
                dict({"lfw_acc_history": lfw_acc_history}))
        plt.plot(lfw_acc_history)
        plt.legend(['LFW-Accuracy (max=%.5f)' % lfw_acc_history.max()])
        plt.grid(True)
        plt.title("center-exclusive$\\times$%.1f" % args.exclusive_weight)
    plt.savefig(join(args.checkpoint, 'record.pdf'))

if __name__ == '__main__':
    main()
# multi-class cross-entropy loss with center-exclusive
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import lr_scheduler
torch.backends.cudnn.benchmark = True
import os, sys, random, datetime, time
from os.path import isdir, isfile, join, dirname, abspath
import argparse
import numpy as np
from PIL import Image
from scipy.io import savemat
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from utils import accuracy, test_lfw, AverageMeter, save_checkpoint, str2bool
from utils import CosineAnnelingLR

THIS_DIR = abspath(dirname(__file__))
TMP_DIR = join(THIS_DIR, 'tmp')
if not isdir(TMP_DIR):
    os.makedirs(TMP_DIR)
parser = argparse.ArgumentParser(description='PyTorch implementation of cross-entropy loss with center-exclusive.')
parser.add_argument('--bs', type=int, help='batch size', default=600)
# optimizer parameters
parser.add_argument('--lr', type=float, help='base learning rate', default=0.1)
parser.add_argument('--momentum', type=float, help='momentum', default=0.9)
parser.add_argument('--stepsize', type=int, help='step size (epoch)', default=18)
parser.add_argument('--gamma', type=float, help='gamma', default=0.1)
parser.add_argument('--wd', type=float, help='weight decay', default=5e-4)
parser.add_argument('--maxepoch', type=int, help='maximal training epoch', default=30)
# model parameters
parser.add_argument('--exclusive_weight', type=float, help='center exclusive loss weight', default=6)
parser.add_argument('--radius', type=float, help='radius', default=15)
# general parameters
parser.add_argument('--print_freq', type=int, help='print frequency', default=50)
parser.add_argument('--train', type=str, help='set to false to test lfw acc only', default="true")
parser.add_argument('--cuda', type=int, help='cuda', default=1)
parser.add_argument('--debug', type=str, help='debug mode', default='false')
parser.add_argument('--checkpoint', type=str, help='checkpoint prefix', default="center_exclusive")
parser.add_argument('--resume', type=str, help='checkpoint path', default=None)
# datasets
parser.add_argument('--casia', type=str, help='root folder of CASIA-WebFace dataset', default="data/CASIA-WebFace-112X96")
parser.add_argument('--num_class', type=int, help='num classes', default=10572)
parser.add_argument('--lfw', type=str, help='LFW dataset root folder', default="data/lfw-112X96")
parser.add_argument('--lfwlist', type=str, help='lfw image list', default='data/LFW_imagelist.txt')
args = parser.parse_args()
assert isfile(args.lfwlist) and isdir(args.lfw)
assert args.exclusive_weight > 0
assert args.cuda == 1
args.train = str2bool(args.train)
args.checkpoint = join(TMP_DIR, args.checkpoint) + "-exclusive_weight%.2f-radius%.1f-" % \
                  (args.exclusive_weight, args.radius) + \
                  datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
print("Checkpoint directory: %s" % args.checkpoint)
if not isdir(args.checkpoint):
    os.makedirs(args.checkpoint)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
if args.train:
    print("Pre-loading training data...")
    train_dataset = datasets.ImageFolder(
        args.casia,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.bs, shuffle=True,
        num_workers=12, pin_memory=True
    )
    print("Done!")
# transforms for LFW testing data
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])
# model and optimizer
from models import CenterExclusive
print("Loading model...")
model = CenterExclusive(num_class=args.num_class, norm_data=True, radius=args.radius)
print("Done!")
# optimizer related
criterion = nn.CrossEntropyLoss()
if True:  # flip to False to use the per-parameter options below
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.momentum)
else:
    # per-parameter options, see documentation https://pytorch.org/docs/stable/optim.html#per-parameter-options
    # assign a larger weight_decay to the center (fc6) parameters
    optimizer = torch.optim.SGD([{'params': model.base.parameters()},
                                 {'params': model.fc6.parameters(), 'lr': args.lr, 'weight_decay': args.wd * 5}
                                ], lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
# scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[16, 24, 28], gamma=args.gamma)
# scheduler = CosineAnnelingLR(optimizer, min_lr=0.01, max_lr=0.1, cycle_length=10)
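# With lr=0.1 and the default gamma=0.1, MultiStepLR gives lr=0.1 for epochs 0-15,
# 0.01 for epochs 16-23, 0.001 for epochs 24-27, and 0.0001 for epochs 28-29.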
if args.cuda:
    print("Transporting model to GPU(s)...")
    model.cuda()
    print("Done!")
def train_epoch(train_loader, model, optimizer, epoch):
    # recording
    loss_cls = AverageMeter()
    loss_exc = AverageMeter()
    top1 = AverageMeter()
    batch_time = AverageMeter()
    train_record = np.zeros((len(train_loader), 4), np.float32)  # loss, exc_loss, top1-acc, lr
    # exclusive loss weight
    # exclusive_weight = float(epoch + 1) ** 2 / float(1000)
    exclusive_weight = args.exclusive_weight
    # switch to train mode
    model.train()
    for batch_idx, (data, label) in enumerate(train_loader):
        it = epoch * len(train_loader) + batch_idx
        start_time = time.time()
        if args.cuda:
            data = data.cuda()
            label = label.cuda(non_blocking=True)
        prob, feature, center_exclusive_loss = model(data)
        loss = criterion(prob, label)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(prob, label, topk=(1, 5))
        loss_cls.update(loss.item(), data.size(0))
        loss_exc.update(center_exclusive_loss.item(), data.size(0))
        top1.update(prec1[0], data.size(0))
        # collect losses
        loss = loss + exclusive_weight * center_exclusive_loss
        # clear cached gradient
        optimizer.zero_grad()
        # backward gradient
        loss.backward()
        # update parameters
        optimizer.step()
        batch_time.update(time.time() - start_time)
        if batch_idx % args.print_freq == 0:
            print("Epoch %d/%d Batch %d/%d, (sec/batch: %.2fsec): loss_cls=%.3f (* 1), loss-exc=%.5f (* %.4f), acc1=%.3f, lr=%f" %
                  (epoch, args.maxepoch, batch_idx, len(train_loader), batch_time.val, loss_cls.val,
                   loss_exc.val, exclusive_weight, top1.val, scheduler.get_lr()[0]))
        train_record[batch_idx, :] = np.array([loss_cls.avg, loss_exc.avg, top1.avg / float(100), scheduler.get_lr()[0]])
    return train_record
def main():
    lfw_acc_history = np.zeros((args.maxepoch, ), np.float32)
    for epoch in range(args.maxepoch):
        scheduler.step()  # adjust learning rate (old-style API: stepped at the start of each epoch)
        if args.train:
            if epoch == 0:
                train_record = train_epoch(train_loader, model, optimizer, epoch)
            else:
                train_record = np.vstack((train_record, train_epoch(train_loader, model, optimizer, epoch)))
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, filename=join(args.checkpoint, "epoch%d.pth" % epoch))
        # prepare data for testing
        with open(args.lfwlist, 'r') as f:
            imglist = f.readlines()
        imglist = [join(args.lfw, i.rstrip()) for i in imglist]
        lfw_acc_history[epoch] = test_lfw(model, imglist, test_transform, join(args.checkpoint, 'epoch%d' % epoch))
        print("Epoch %d best LFW accuracy is %.5f." % (epoch, lfw_acc_history.max()))
    if args.train:
        savemat(join(args.checkpoint, 'record(max-acc=%.5f).mat' % lfw_acc_history.max()),
                dict({"train_record": train_record,
                      "lfw_acc_history": lfw_acc_history}))
        fig, axes = plt.subplots(1, 5, figsize=(15, 3))
        for ax in axes:
            ax.grid(True)
        axes[0].plot(train_record[:, 0], 'r')  # loss cls
        axes[0].set_title("CELoss")
        axes[1].plot(train_record[:, 1], 'r')  # loss exclusive
        axes[1].set_title("ExLoss")
        axes[2].plot(train_record[:, 2], 'r')  # top1 acc
        axes[2].set_title("Trn-Acc")
        axes[3].plot(train_record[:, 3], 'r')  # learning rate
        axes[3].set_title("LR")
        axes[4].plot(lfw_acc_history.argmax(), lfw_acc_history.max(), 'r*', markersize=12)
        axes[4].plot(lfw_acc_history, 'r')
        axes[4].set_title("LFW-Acc")
    else:
        savemat(join(args.checkpoint, 'record(max-acc=%.5f).mat' % lfw_acc_history.max()),
                dict({"lfw_acc_history": lfw_acc_history}))
        plt.plot(lfw_acc_history)
        plt.legend(['LFW-Accuracy (max=%.5f)' % lfw_acc_history.max()])
        plt.grid(True)
    plt.savefig(join(args.checkpoint, 'radius%.1f-exweight%.1f.pdf' % (args.radius, args.exclusive_weight)))

if __name__ == '__main__':
    main()