@NegatioN
Last active July 31, 2018 07:51
Circular LR
# Circular LR as implemented in fast.ai, but without depending on the rest of its internals.
class CircularLR:
    def __init__(self, optimizer, nb, div=10, pct=10, momentums=None, on_cycle_end=None):
        self.nb, self.div, self.pct = nb, div, pct
        # The nb iterations split into an LR ramp-up, a ramp-down of equal length,
        # and a final annealing phase covering the last `pct` percent of iterations.
        self.cycle_nb = int(nb * (1 - pct / 100) / 2)
        self.opt = optimizer
        self.init_lr = self.opt.param_groups[0]['lr']
        self.on_cycle_end = on_cycle_end  # optional callback; was referenced but never set in the original
        if momentums is not None:
            self.moms = momentums

    def on_begin(self):
        self.cycle_iter, self.cycle_count = 0, 0
        self.update_lr()
        self.update_mom()

    def on_batch_end(self):
        self.update_lr()
        self.update_mom()

    def update_mom(self):
        self.set_mom(self.calc_mom())

    def update_lr(self):
        new_lr = self.calc_lr(self.init_lr)
        self.set_lr(new_lr)

    def set_lr(self, lr):
        for pg in self.opt.param_groups:
            pg['lr'] = lr

    def set_mom(self, momentum):
        # Adam-style optimizers store momentum as the first beta; SGD-style use 'momentum'.
        if 'betas' in self.opt.param_groups[0]:
            for pg in self.opt.param_groups:
                pg['betas'] = (momentum, pg['betas'][1])
        else:
            for pg in self.opt.param_groups:
                pg['momentum'] = momentum

    def calc_lr(self, init_lrs):
        if self.cycle_iter > 2 * self.cycle_nb:
            # Final phase: anneal linearly from init_lrs/div down to 1% of init_lrs/div.
            pct = (self.cycle_iter - 2 * self.cycle_nb) / (self.nb - 2 * self.cycle_nb)
            res = init_lrs * (1 + pct * (1 - 100) / 100) / self.div
        elif self.cycle_iter > self.cycle_nb:
            # Second phase: ramp back down from init_lrs to init_lrs/div.
            pct = 1 - (self.cycle_iter - self.cycle_nb) / self.cycle_nb
            res = init_lrs * (1 + pct * (self.div - 1)) / self.div
        else:
            # First phase: ramp up from init_lrs/div to init_lrs.
            pct = self.cycle_iter / self.cycle_nb
            res = init_lrs * (1 + pct * (self.div - 1)) / self.div
        self.cycle_iter += 1
        if self.cycle_iter == self.nb:
            self.cycle_iter = 0
            if self.on_cycle_end:
                self.on_cycle_end(self, self.cycle_count)
            self.cycle_count += 1
        return res

    def calc_mom(self):
        # Momentum moves opposite to the LR: it falls from moms[0] to moms[1] while
        # the LR ramps up, climbs back while the LR ramps down, then holds at moms[0].
        if self.cycle_iter > 2 * self.cycle_nb:
            res = self.moms[0]
        elif self.cycle_iter > self.cycle_nb:
            pct = 1 - (self.cycle_iter - self.cycle_nb) / self.cycle_nb
            res = self.moms[0] + pct * (self.moms[1] - self.moms[0])
        else:
            pct = self.cycle_iter / self.cycle_nb
            res = self.moms[0] + pct * (self.moms[1] - self.moms[0])
        return res
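
# A quick sanity check of the schedule shape (a minimal sketch; the dummy
# parameter and SGD optimizer below are illustrative, not part of the original
# gist). Over one cycle the LR should ramp from init_lr/div up to init_lr,
# back down to init_lr/div, then anneal towards 1% of init_lr/div, while
# momentum mirrors it in the opposite direction.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1, momentum=0.9)
sched = CircularLR(opt, nb=100, div=10, pct=10, momentums=(0.95, 0.85))
sched.on_begin()
for i in range(100):
    if i % 10 == 0:
        print(i, opt.param_groups[0]['lr'], opt.param_groups[0]['momentum'])
    sched.on_batch_end()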
# Example setup: a single cycle spanning 30 epochs with Adam
# (model and train_loader are assumed to be defined elsewhere).
div, pct = 10, 13.68
moms = 0.95, 0.85  # use_clr_beta[2:] if len(use_clr_beta) > 3 else None
cycle_len = 30
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
callbacks = [CircularLR(optimizer, len(train_loader) * cycle_len, div=div, pct=pct, momentums=moms)]
# In the training loop, call callbacks[0].on_begin() once before training starts
# (calling it every epoch would reset the cycle counter) and
# callbacks[0].on_batch_end() after every iteration.
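
# A minimal training-loop sketch under the setup above (model, train_loader,
# and the use of cross-entropy loss are assumptions for illustration; they
# are not defined in this gist):
import torch.nn.functional as F

callbacks[0].on_begin()  # initialize LR/momentum once, before the cycle starts
for epoch in range(cycle_len):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        callbacks[0].on_batch_end()  # advance the schedule by one iteration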