@redknightlois
Last active August 9, 2023 20:50
Ralamb optimizer (RAdam + LARS trick)
import math

import torch
from torch.optim.optimizer import Optimizer


class Ralamb(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Small cache of [step, N_sma, radam_step_size], indexed by step % 10
        self.buffer = [[None, None, None] for ind in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficient
                # (keyword alpha=/value= form; the scalar-first positional form is deprecated)
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, radam_step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        radam_step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        radam_step_size = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = radam_step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                # radam_step holds the tentatively updated weights (weights + RAdam update)
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size)
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size)

                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1.0
                else:
                    # Plain Python float so it can be passed as alpha=/value=
                    trust_ratio = (weight_norm / radam_norm).item()

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                # Actual update, scaled by the LARS-style trust ratio
                if N_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-radam_step_size * trust_ratio)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-radam_step_size * trust_ratio)

                p.data.copy_(p_data_fp32)

        return loss
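
To get a feel for when the rectified branch (`N_sma >= 5`) starts to apply, the step-size schedule computed inside `step()` can be reproduced without PyTorch. This is only a minimal sketch: the helper names `n_sma` and `radam_step_size` are made up for illustration and simply mirror the formulas in the code above, with the default `lr` and `betas`.

```python
import math


def n_sma(step, beta2):
    """Approximated SMA length, same formula as N_sma in Ralamb.step()."""
    beta2_t = beta2 ** step
    n_sma_max = 2 / (1 - beta2) - 1
    return n_sma_max - 2 * step * beta2_t / (1 - beta2_t)


def radam_step_size(step, lr=1e-3, betas=(0.9, 0.999)):
    """Step size at a given step; rectified only when N_sma >= 5."""
    beta1, beta2 = betas
    n = n_sma(step, beta2)
    n_max = 2 / (1 - beta2) - 1
    bias_correction1 = 1 - beta1 ** step
    if n >= 5:
        # Variance rectification term, as in the rectified branch above
        rect = math.sqrt(
            (1 - beta2 ** step)
            * (n - 4) / (n_max - 4)
            * (n - 2) / n
            * n_max / (n_max - 2)
        )
        return lr * rect / bias_correction1
    # Early steps: plain bias-corrected momentum step, no second moment
    return lr / bias_correction1


if __name__ == "__main__":
    for t in range(1, 9):
        print(t, round(n_sma(t, 0.999), 3), radam_step_size(t))
```

Printing the first few steps shows `N_sma` growing roughly one unit per step, so with the default `beta2` the optimizer spends its first handful of updates in the unrectified branch.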
@redknightlois
Author

redknightlois commented Aug 30, 2019

I am still waiting to get some results on the 20 epoch run, but the 5 epoch runs on Ralars, Ralamb (v4), Ralamb (v1), RAdam and Lookahead show the following:

  • Lookahead and RAdam are on top.
  • Ralamb v4 comes next.
  • Ralars and Ralamb v1 are comparable.

This is what we know so far:

  • V1 was flawed (yet somehow got good results) and, AFAIK (@mgrankin should comment on that), it is what produced the results on the main page.
  • V2 and V3 are even worse... They do not converge at all when run in isolation (which makes sense, because we are not calculating the proper complete step before taking the actual step forward).
  • V4: those differences are completely unintended; they are bugs. However, if what @frgfm is saying is right, we should probably dig there, because there may be something we don't yet know about the behavior.
  • I was not able to locally reproduce the results that @mgrankin's repo has in the README (weird, I know).

I have been working on a variant of Lookahead because it looks as if LARS-style optimizers do not play well with Lookahead (which is why Ranger beat all the others). But I am stuck; I need someone with a stronger math background to figure out the formula.

The idea is to incorporate the basic idea of LARS directly into Lookahead. Put it this way: you probably want to overshoot if the gradient has generally the same direction in all intermediate steps. So if you take t0, ..., tk/2, ..., tk, the vectors t0->tk/2 and t0->tk should have a cosine similarity of 1; in that case you are probably looking at a "deep dive" kind of landscape. However, if that cosine similarity is going to 0, we are probably in exploration mode and the buddy should stay a bit away in case the explorer trips into a hole (a local minimum).

If the cosine similarity is 1, there is no harm in overshooting (trust ratio > 1), even though the buddy is going to be updated proportionally to alpha in the fast direction. While Lookahead today updates the slow weights but doesn't change the fast gradient, I am saying that, from the buddy's point of view, we could push the fast gradient into overshooting mode because he believes it is safe to do so. And if it fails, you have just wasted k batches.
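
The direction check described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the function names are hypothetical, the weights are treated as flat lists of floats, and the sketch stops at the similarity itself, since how to map it to a trust ratio is exactly the open question.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)


def direction_agreement(fast_weights):
    """Compare the displacements t0 -> t_{k/2} and t0 -> t_k.

    `fast_weights` is a list of k + 1 flat weight vectors [t0, ..., tk].
    """
    k = len(fast_weights) - 1
    start, mid, end = fast_weights[0], fast_weights[k // 2], fast_weights[k]
    half = [m - s for m, s in zip(mid, start)]
    full = [e - s for e, s in zip(end, start)]
    return cosine_similarity(half, full)


# Straight-line trajectory: both displacements point the same way.
straight = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
# Trajectory whose end point is orthogonal to its first half.
turning = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 1.0], [0.0, 2.0]]

print(direction_agreement(straight))
print(direction_agreement(turning))
```

The first trajectory gives a similarity of (numerically) 1, the "safe to overshoot" case, while the second gives 0, the "exploration" case.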

What do you think?

@frgfm

frgfm commented Aug 30, 2019

Interesting thought @redknightlois
If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)

And you remove LARS from the actual base optimizer?

I can draft a few options on the math of this and implement it, but I'll probably need something to run the training as I believe I don't have the hardware to meet the same training conditions as you guys!

@frgfm

frgfm commented Aug 30, 2019

For the gradient, I can actually picture cases where the cosine similarity of gradients wouldn't help (if that's what you meant). Since it's a wrapper, we don't control the way the base optimizer performs the fast steps. So imagine that we have an optimizer giving a lot of importance to momentum or another inertia factor, and that the gradients of the fast weights at each step are very similar (in cosine distance).

The projection of the fast weights onto a 2D scatter plot may then actually be a curve rather than a line (optimizer magic). In this case, we would overshoot further than the last fast weights, but in a direction that would not fit the curve. I would have to test the hypothesis (perhaps it doesn't occur that often, but my intuition is that optimizers whose updates are not driven mainly by the gradient might not work well with this approach).

While implementing Lookahead, I actually had a somewhat similar idea @redknightlois, but more from a model forward/backward efficiency perspective.
I didn't consider overshooting, but selecting alpha automatically based on the variance of the fast updates:
let s be the slow weights, and f1, ..., fk be the fast weights

  • I evaluate the variance of [f2 - f1, ..., fk - f(k-1)] which would characterize the consistency of the fast updates' direction.
  • I would normalize it / squeeze it into [0, 1] and use 1 - squeezed_variance as a synchronization rate (alpha)

The issue I had is that, memory-wise, I would have to store k-1 fast weights to perform this. So, hopefully, less computation is required, but memory usage is higher (linear in the synchronization period k). I'll check if I can avoid that memory overhead and try to implement this.
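
As a rough illustration of the two bullets above, here is a minimal sketch in plain Python. The function name `auto_alpha` is hypothetical, and the squeezing step is an assumption, not something specified in this thread: the total variance of the updates is divided by their mean squared norm, which always lands in [0, 1].

```python
def auto_alpha(fast_weights):
    """Synchronization rate from the consistency of the fast updates.

    `fast_weights` is the stored list [f1, ..., fk] of flat weight vectors.
    ASSUMPTION: the variance is squeezed into [0, 1] by dividing the total
    variance of the updates by their mean squared norm.
    """
    if len(fast_weights) < 2:
        raise ValueError("need at least two fast weight vectors")

    # Fast updates [f2 - f1, ..., fk - f(k-1)]
    deltas = [
        [b - a for a, b in zip(prev, cur)]
        for prev, cur in zip(fast_weights, fast_weights[1:])
    ]
    n = len(deltas)
    dim = len(deltas[0])

    mean = [sum(d[i] for d in deltas) / n for i in range(dim)]
    mean_sq_norm = sum(x * x for d in deltas for x in d) / n
    if mean_sq_norm == 0:
        # Nothing moved, so any alpha gives the same slow weights
        return 1.0

    # Per-coordinate variance of the updates, summed over coordinates
    total_var = sum(
        (d[i] - mean[i]) ** 2 for d in deltas for i in range(dim)
    ) / n
    squeezed = total_var / mean_sq_norm
    # Clamp against floating-point round-off
    return min(1.0, max(0.0, 1.0 - squeezed))


# Consistent updates: every step moves in the same direction.
consistent = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
# Oscillating updates: each step undoes the previous one.
oscillating = [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]

print(auto_alpha(consistent))
print(auto_alpha(oscillating))
```

With this choice, perfectly consistent updates give alpha = 1 and updates that cancel each other out give alpha = 0. Note that the sketch still keeps all the fast weights in memory, which is the overhead described above.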

@mgrankin

I re-ran the notebooks with the latest Ralamb and RangerLars and updated the main page with the results.

@redknightlois
Author

redknightlois commented Aug 30, 2019

> And you remove LARS from the actual base optimizer?

Yes, you take out LARS from the base optimizer.

> If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)

Yes, I suggested cosine similarity because you can picture it in your head in 2 dimensions. As you mentioned, when the 'curve' is an actual curve, the similarity of the vectors is bad, and therefore you have to use the normal schedule...

In 2D this is what I have in mind.
[image]

> I didn't consider overshooting, but selecting alpha automatically based on the variance of the fast updates:

I did a LARS-style version of Lookahead which uses the 'trust_ratio' between the norms of the fast and slow weights... and at 5 epochs you don't see a noticeable change... but I haven't had the time to run it further or with an annealing schedule, which looks like it is making a lot of difference based on the results just published by @mgrankin. Quick question: do those numbers already include the change to use the Mish activation?

EDIT: Quick caveat... when I say you can overshoot, what I mean is that you update the fast weights into overshooting, and the slow weights are modified in such a way that they do not move into the overshooting range.

@mgrankin

> Quick question, those numbers already include the change to use Mish activation?

No, I haven't looked into Mish yet.

@VirajBagal

@redknightlois Hi, thanks for the implementation. I have a question: how do I save the Ralamb optimizer state_dict? There is no function for that, and there is no load_state_dict function either. Thanks.

@redknightlois
Author

No, this was a prototype that I knocked up in a few hours. Feel free to add those and I will update it.
