Find a learning rate using MXNet (https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/training/learning_rates/learning_rate_finder.html)
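The script below implements the learning rate range test from the linked tutorial: it runs one forward/backward pass per step on a pretrained gluoncv model, starting from a tiny learning rate (default 1e-6) and multiplying it by a constant factor (default 1.1) after every iteration while recording the loss. The sweep stops once a smoothed running mean of the loss exceeds twice the first recorded loss (after at least 20 iterations); the saved parameters and optimizer state are then restored, and the (learning rate, loss) pairs are plotted on log-log axes and written to find_lr.png. With the defaults, reaching a learning rate of about 1.0 takes roughly ln(10^6)/ln(1.1) ≈ 145 iterations, assuming the stopping criterion has not fired earlier.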
import mxnet as mx
import argparse, os
from matplotlib import pyplot as plt
from gluoncv.model_zoo import get_model

#mx.random.seed(42)

# CLI
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model for image classification.')
    parser.add_argument('--classes', type=int, default=1000,
                        help='number of classes')
    parser.add_argument('--rec-train', type=str, default='images_train.rec',
                        help='the training data')
    parser.add_argument('--rec-train-idx', type=str, default='images_train.idx',
                        help='the index of training data')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for training. default is float32')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer, default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--model', type=str, required=True,
                        help='type of model to use. see vision_model for options.')
    parser.add_argument('--no-wd', action='store_true',
                        help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
    opt = parser.parse_args()
    return opt

def prep_net(classes, context, opt):
    # Load a pretrained model and replace its classifier head with a fresh Dense layer.
    net = get_model(opt.model, pretrained=True)
    with net.name_scope():
        net.output = mx.gluon.nn.Dense(classes)
    net.output.initialize(mx.init.Xavier(), ctx=context)
    net.collect_params().reset_ctx(context)
    if opt.no_wd:
        # Disable weight decay on bias and batchnorm beta/gamma parameters.
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0
    net.hybridize(static_alloc=True, static_shape=True)
    return net

class Learner():
    def __init__(self, net, data_loader, ctx, opt):
        """
        :param net: network (mx.gluon.Block)
        :param data_loader: training data iterator (mx.io.DataIter, e.g. ImageRecordIter)
        :param ctx: list of contexts (mx.gpu() or mx.cpu())
        """
        self.net = net
        self.opt = opt
        self.data_loader = data_loader
        self.ctx = ctx
        # Only parameters that are still uninitialized are affected here; parameters
        # already initialized in prep_net are kept (Gluon warns and skips them).
        #self.net.initialize(mx.init.Xavier(), ctx=self.ctx)
        self.net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
        self.loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss()
        optimizer_params = {'learning_rate': .001, 'wd': self.opt.wd, 'momentum': self.opt.momentum}
        if self.opt.dtype != 'float32':
            optimizer_params['multi_precision'] = True
        self.trainer = mx.gluon.Trainer(net.collect_params(), 'nag', optimizer_params)

    def iteration(self, lr=None, take_step=True):
        """
        :param lr: learning rate to use for iteration (float)
        :param take_step: take trainer step to update weights (boolean)
        :return: iteration loss (float)
        """
        # Update learning rate if different this iteration
        if lr and (lr != self.trainer.learning_rate):
            self.trainer.set_learning_rate(lr)
        # Get next batch, and move it to the context(s) (e.g. to GPU if set)
        try:
            bt = next(self.data_loader)
        except StopIteration:
            self.data_loader.reset()
            bt = next(self.data_loader)
        data = mx.gluon.utils.split_and_load(bt.data[0], ctx_list=self.ctx, batch_axis=0)
        label = mx.gluon.utils.split_and_load(bt.label[0], ctx_list=self.ctx, batch_axis=0)
        # Standard forward and backward pass
        with mx.autograd.record():
            outputs = [self.net(X.astype(self.opt.dtype, copy=False)) for X in data]
            loss = [self.loss_fn(yhat, y.astype(self.opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
        for l in loss:
            l.backward()
        # Update parameters, normalizing gradients by the total batch size across devices
        if take_step:
            self.trainer.step(sum(d.shape[0] for d in data))
        # Set and return loss.
        ls = 0.0
        for l in loss:
            ls += mx.nd.mean(l).asscalar()
        self.iteration_loss = ls
        return self.iteration_loss

    def close(self):
        # Close open iterator and associated workers
        self.data_loader.shutdown()

class LRFinder():
    def __init__(self, learner):
        """
        :param learner: able to take single iteration with given learning rate and return loss,
                        and save and load parameters of the network (Learner)
        """
        self.learner = learner

    def find(self, lr_start=1e-6, lr_multiplier=1.1, smoothing=0.3):
        """
        :param lr_start: learning rate to start search (float)
        :param lr_multiplier: factor the learning rate is multiplied by at each step of search (float)
        :param smoothing: amount of smoothing applied to loss for stopping criteria (float)
        :return: learning rate and loss pairs (list of (float, float) tuples)
        """
        # Pass data through once without taking a step, so that lazily
        # initialized weights get their shapes set.
        self.learner.iteration(take_step=False)
        # Used to initialize trainer (if no step has been taken)
        if not self.learner.trainer._kv_initialized:
            self.learner.trainer._init_kvstore()
        # Store params and optimizer state for restore after lr_finder procedure.
        # Useful for applying the method partway through training, not just for initialization of lr.
        self.learner.net.save_parameters("lr_finder.params")
        self.learner.trainer.save_states("lr_finder.state")
        lr = lr_start
        self.results = []  # List of (lr, loss) tuples
        stopping_criteria = LRFinderStoppingCriteria(smoothing)
        while True:
            # Run iteration, and block until loss is calculated.
            loss = self.learner.iteration(lr)
            self.results.append((lr, loss))
            if stopping_criteria(loss):
                break
            lr = lr * lr_multiplier
        # Restore params (as finder changed them)
        self.learner.net.load_parameters("lr_finder.params", ctx=self.learner.ctx)
        self.learner.trainer.load_states("lr_finder.state")
        return self.results

    def plot(self):
        lrs = [e[0] for e in self.results]
        losses = [e[1] for e in self.results]
        plt.figure(figsize=(6, 8))
        plt.scatter(lrs, losses)
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        plt.xscale('log')
        plt.yscale('log')
        axes = plt.gca()
        axes.set_xlim([lrs[0], lrs[-1]])
        y_lower = min(losses) * 0.8
        y_upper = losses[0] * 4
        axes.set_ylim([y_lower, y_upper])
        plt.savefig("find_lr.png")
        plt.show()

class LRFinderStoppingCriteria():
    def __init__(self, smoothing=0.3, min_iter=20):
        """
        :param smoothing: applied to running mean which is used for thresholding (float)
        :param min_iter: minimum number of iterations before early stopping can occur (int)
        """
        self.smoothing = smoothing
        self.min_iter = min_iter
        self.first_loss = None
        self.running_mean = None
        self.counter = 0

    def __call__(self, loss):
        """
        :param loss: from single iteration (float)
        :return: indicator to stop (boolean)
        """
        self.counter += 1
        if self.first_loss is None:
            self.first_loss = loss
        if self.running_mean is None:
            self.running_mean = loss
        else:
            self.running_mean = ((1 - self.smoothing) * loss) + (self.smoothing * self.running_mean)
        # Stop once the smoothed loss has doubled relative to the first loss
        return (self.running_mean > self.first_loss * 2) and (self.counter >= self.min_iter)

def get_data_rec(rec_train, rec_train_idx, batch_size):
    rec_train = os.path.expanduser(rec_train)
    rec_train_idx = os.path.expanduser(rec_train_idx)
    input_size = 224
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]
    train_data = mx.io.ImageRecordIter(
        path_imgrec = rec_train,
        path_imgidx = rec_train_idx,
        preprocess_threads = 4,
        shuffle = True,
        batch_size = batch_size,
        data_shape = (3, input_size, input_size),
        mean_r = mean_rgb[0],
        mean_g = mean_rgb[1],
        mean_b = mean_rgb[2],
        std_r = std_rgb[0],
        std_g = std_rgb[1],
        std_b = std_rgb[2],
        rand_crop = True
    )
    return train_data

opt = parse_args()
context = [mx.gpu(i) for i in range(opt.num_gpus)] if opt.num_gpus > 0 else [mx.cpu()]
net = prep_net(opt.classes, context, opt)
train_data = get_data_rec(opt.rec_train, opt.rec_train_idx, opt.batch_size)
learner = Learner(net=net, data_loader=train_data, ctx=context, opt=opt)
lr_finder = LRFinder(learner)
lr_finder.find(lr_start=1e-6)
lr_finder.plot()
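
# A hypothetical follow-up (not part of the original gist): once find() has
# populated lr_finder.results, a common heuristic is to start training with a
# learning rate roughly an order of magnitude below the point where the loss
# curve bottoms out. A minimal sketch under that assumption:
best_lr, _ = min(lr_finder.results, key=lambda pair: pair[1])
suggested_lr = best_lr / 10.0
print("lowest-loss lr: %g, suggested starting lr: %g" % (best_lr, suggested_lr))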