@NegatioN
Last active September 14, 2018 11:25
find learning rate from fast.ai
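
The function below implements the learning-rate range test from fast.ai: the learning rate is raised exponentially each mini-batch, from init_value to final_value over one pass of the loader, while a smoothed loss is recorded, and the sweep stops early once the loss explodes. As a standalone sanity check of the multiplicative schedule (toy values, not part of the gist):

# After `num` steps of lr *= mult, lr should land on final_value.
init_value, final_value, num = 1e-8, 10.0, 100
mult = (final_value / init_value) ** (1 / num)
lr = init_value
for _ in range(num):
    lr *= mult
print(lr)  # ~10.0, up to floating-point rounding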
import math
import torch
import matplotlib.pyplot as plt
%matplotlib inline

# Assumes a global `device` is already defined, e.g.:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def find_lr(net, criterion, optimizer, trn_loader, init_value=1e-8, final_value=10., beta=0.98):
    num = len(trn_loader) - 1
    mult = (final_value / init_value) ** (1 / num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    batch_num = 0
    losses = []
    log_lrs = []
    for data in trn_loader:
        batch_num += 1
        # As before, get the loss for this mini-batch of inputs/outputs
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # NOTE: this forward/loss pattern is specific to the author's
        # WARP-style ranking model; adapt these two lines for other models.
        pos, neg = net(labels, inputs)
        loss = criterion(pos.view(-1, 1), neg)
        # Compute the smoothed loss (bias-corrected exponential moving average)
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta ** batch_num)
        # Stop if the loss is exploding
        if batch_num > 1 and smoothed_loss > 4 * best_loss:
            return log_lrs, losses
        # Record the best loss
        if smoothed_loss < best_loss or batch_num == 1:
            best_loss = smoothed_loss
        # Store the values
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))
        # Do the SGD step
        loss.backward()
        optimizer.step()
        # Update the lr for the next step
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    return log_lrs, losses
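
The smoothing above is an exponentially weighted moving average with the same bias correction Adam applies to its moment estimates: for small batch_num the raw average avg_loss badly underestimates the loss (it starts at 0), and dividing by (1 - beta**batch_num) compensates. A standalone illustration with made-up loss values:

beta = 0.98
avg = 0.0
for t, raw in enumerate([2.30, 2.20, 2.25, 2.10], start=1):
    avg = beta * avg + (1 - beta) * raw
    # raw EMA vs. bias-corrected estimate; the corrected value tracks ~2.2
    print(t, round(avg, 4), round(avg / (1 - beta ** t), 4))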
'''
Example run. Select the highest learning rate just before the curve turns
erratic, i.e. while the smoothed loss is still clearly decreasing.
model = BasicEmbModel(len(occupations), len(locations), len(industries), len(job_tags), len(user_ids), 25).to(device)
occu_warp = partial(warp_loss, num_labels=torch.FloatTensor([len(adids)]).to(device), device=device, limit_grad=False)
logs, losses = find_lr(model, occu_warp, torch.optim.RMSprop(model.parameters(), lr=0.05), train_loader, init_value=0.001, final_value=1000)
m = -5  # drop the last few points, where the loss typically blows up
plt.plot(logs[10:m], losses[10:m])
'''
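
The advice above ("highest LR before the graph looks erratic") can also be automated. A hedged sketch, using synthetic logs/losses in place of a real run: one common heuristic, similar in spirit to the "steep" suggestion in fastai's lr_find, is to take the learning rate at the point of steepest loss descent.

import numpy as np
# Synthetic stand-ins for the (log_lrs, losses) a real sweep would return.
logs = np.linspace(-8, 1, 90)                        # log10 of the swept lrs
losses = 2.5 - 1.5 / (1 + np.exp(-3 * (logs + 4)))   # fake curve: drops, then flattens
steepest = logs[np.argmin(np.gradient(losses, logs))]
print('suggested lr ~ 10 **', round(float(steepest), 2))  # ~1e-4 for this fake curve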