Learning-rate finder (LR range test) from fast.ai: sweep the learning rate exponentially from init_value to final_value over one pass of the data (lr_t = init_value * mult**t), track a smoothed loss, and stop once the loss explodes. Plot loss against log10(lr) and pick the highest learning rate on the steep, still-descending part of the curve.
import math

import matplotlib.pyplot as plt
import torch

%matplotlib inline

# Assumes a global `device` (e.g. torch.device('cuda')) shared with the model.

def find_lr(net, criterion, optimizer, trn_loader, init_value=1e-8, final_value=10., beta=0.98):
    num = len(trn_loader) - 1
    # Multiplier that sweeps the LR exponentially from init_value to
    # final_value over one pass: init_value * mult**num == final_value.
    mult = (final_value / init_value) ** (1 / num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    batch_num = 0
    losses = []
    log_lrs = []
    for data in trn_loader:
        batch_num += 1
        # As before, get the loss for this mini-batch of inputs/outputs
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        pos, neg = net(labels, inputs)
        loss = criterion(pos.view(-1, 1), neg)
        # Compute the smoothed loss: exponential moving average with
        # bias correction for the zero initialization
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta ** batch_num)
        # Stop if the loss is exploding
        if batch_num > 1 and smoothed_loss > 4 * best_loss:
            return log_lrs, losses
        # Record the best loss
        if smoothed_loss < best_loss or batch_num == 1:
            best_loss = smoothed_loss
        # Store the values
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))
        # Do the SGD step
        loss.backward()
        optimizer.step()
        # Update the lr for the next step
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    return log_lrs, losses
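'''
Why divide by (1 - beta**batch_num)? The running average avg_loss starts at
zero, so its early values are biased low; dividing by (1 - beta**t) is the
same bias correction Adam applies to its moment estimates. A quick numeric
check (my own sketch, assuming beta=0.98 and a constant loss of 1.0; the
corrected value stays at ~1.0 from the very first batch):

avg = 0.0
for t in range(1, 4):
    avg = 0.98 * avg + 0.02 * 1.0
    print(avg / (1 - 0.98 ** t))  # ~1.0 at every step
'''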
'''
Example run. Pick the highest LR just before the curve becomes erratic:

model = BasicEmbModel(len(occupations), len(locations), len(industries), len(job_tags), len(user_ids), 25).to(device)
occu_warp = partial(warp_loss, num_labels=torch.FloatTensor([len(adids)]).to(device), device=device, limit_grad=False)
logs, losses = find_lr(model, occu_warp, torch.optim.RMSprop(model.parameters(), lr=0.05), train_loader, init_value=0.001, final_value=1000)
m = -5
plt.plot(logs[10:m], losses[10:m])  # skip the first 10 and last 5 noisy points
'''
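'''
Picking the LR programmatically (not part of the original gist): a common rule
of thumb is to take a learning rate roughly one order of magnitude below the
point of minimum smoothed loss. suggest_lr below is a hypothetical helper
built on that heuristic; it assumes numpy and the find_lr outputs above.

import numpy as np

def suggest_lr(log_lrs, losses, skip_start=10, skip_end=5):
    # Trim the noisy ends of the sweep before locating the minimum.
    lrs = np.asarray(log_lrs[skip_start:-skip_end])
    ls = np.asarray(losses[skip_start:-skip_end])
    # One order of magnitude below the minimum-loss learning rate.
    return 10 ** (lrs[ls.argmin()] - 1.0)

lr = suggest_lr(logs, losses)
'''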