```python
import math

import torch
from torch.optim.optimizer import Optimizer


# RAdam (rectified Adam) combined with LARS-style layer-wise trust ratios.
class Ralamb(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Cache of (step, N_sma, step_size) triples, keyed by step % 10.
        self.buffer = [[None, None, None] for _ in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficients.
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, radam_step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    # Length of the approximated SMA (simple moving average).
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # More conservative since it's an approximated value.
                    if N_sma >= 5:
                        # Variance is tractable: apply the RAdam rectification term.
                        radam_step_size = group['lr'] * math.sqrt(
                            (1 - beta2_t)
                            * (N_sma - 4) / (N_sma_max - 4)
                            * (N_sma - 2) / N_sma
                            * N_sma_max / (N_sma_max - 2)
                        ) / (1 - beta1 ** state['step'])
                    else:
                        # Warm-up phase: fall back to an un-adapted step.
                        radam_step_size = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = radam_step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # Compute the would-be RAdam update to measure its norm.
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size)
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size)

                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1.0
                else:
                    # LARS-style layer-wise trust ratio.
                    trust_ratio = (weight_norm / radam_norm).item()

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                # Apply the RAdam update scaled by the trust ratio.
                if N_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-radam_step_size * trust_ratio)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-radam_step_size * trust_ratio)

                p.data.copy_(p_data_fp32)

        return loss
```
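For anyone wanting to try it, here is a minimal usage sketch (not part of the gist): Ralamb drops in like any other `torch.optim` optimizer, and the RangerLars combination discussed below wraps it in a Lookahead implementation, whose import and signature are assumptions here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal usage sketch (not from the gist). Ralamb behaves like any
# torch.optim optimizer; wrapping it in a Lookahead implementation yields
# the RangerLars combination discussed in this thread.
model = nn.Linear(10, 2)
optimizer = Ralamb(model.parameters(), lr=1e-3, weight_decay=1e-2)
# optimizer = Lookahead(optimizer, alpha=0.5, k=6)  # hypothetical wrapper signature

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```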
For the gradient, I can actually picture cases where cosine similarity of gradients wouldn't help (if that's what you meant). Since Lookahead is a wrapper, we don't control how the base optimizer performs the fast steps. So imagine an optimizer that gives a lot of weight to momentum or some other inertia factor, and where the gradient of the fast weights at each step is very similar (in cosine distance).
The projection of the fast weights onto a 2D scatter plot might then actually be a curve rather than a line (optimizer magic). In that case, we would overshoot past the last fast weights, but in a direction that would not fit the curve. I would have to test the hypothesis; perhaps it doesn't occur that often, but my intuition is that optimizers whose updates are not driven mainly by the gradient might not work well with this approach.
While implementing Lookahead, I actually had a somewhat similar idea @redknightlois, but more from a model forward/backward efficiency perspective.
I didn't consider overshooting, but rather selecting alpha automatically based on the variance of the fast updates:
Let `s` be the slow weights and `f_1, ..., f_k` be the fast weights:
- I evaluate the variance of `[f_2 - f_1, ..., f_k - f_(k-1)]`, which would characterize the consistency of the fast updates' direction.
- I would normalize it / squeeze it into [0, 1] and use `1 - squeezed_variance` as the synchronization rate (alpha).
The issue I had is that, memory-wise, I would have to store `k-1` fast weights to perform this. So, hopefully, less computation required, but higher memory usage (growing linearly with the synchronization period `k`). I'll check if I can avoid that memory overhead and try to implement this.
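To make that concrete, here is a rough sketch of the variance-based alpha under my own assumptions; the squashing function, the flattened snapshots, and the helper name are illustrative choices, not a worked-out design:

```python
import torch

def variance_based_alpha(fast_weights):
    """Sketch: derive Lookahead's alpha from the consistency of fast updates.

    fast_weights: list of k parameter snapshots f_1 ... f_k, each a flat tensor.
    Returns 1 - squashed_variance, so consistent update directions give a
    large synchronization rate and erratic ones give a small rate.
    """
    # Differences between consecutive fast weights: [f_2 - f_1, ..., f_k - f_(k-1)]
    deltas = torch.stack([b - a for a, b in zip(fast_weights[:-1], fast_weights[1:])])
    # Variance of the update vectors across steps, averaged over coordinates.
    variance = deltas.var(dim=0).mean()
    # Squash into [0, 1]; the v / (1 + v) form is an arbitrary illustrative choice.
    squashed = variance / (1.0 + variance)
    return (1.0 - squashed).item()

# Example: four snapshots of a flattened parameter vector.
snapshots = [torch.randn(100) for _ in range(4)]
alpha = variance_based_alpha(snapshots)  # in (0, 1]
```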
I re-ran the notebooks with the latest Ralamb and RangerLars and updated the main page with the results.
And you remove LARS from the actual base optimizer?
Yes, you take out LARS from the base optimizer.
If I get your idea correctly, you want to rescale the interpolation/update of the slow weights in Lookahead based on a distance metric between the gradients of the intermediate fast steps (and you're suggesting cosine similarity as the distance metric)?
Yes, I suggested cosine similarity because you can picture it in your head in two dimensions. As you mentioned, when the 'curve' is an actual curve, the similarity of the vectors is low, and therefore you have to use the normal schedule...
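A minimal sketch of that suggestion, with hypothetical names: scale Lookahead's alpha by the mean cosine similarity of consecutive fast updates, so a straight trajectory keeps the full interpolation and a curved one falls back toward the normal schedule.

```python
import torch
import torch.nn.functional as F

def cosine_scaled_alpha(fast_weights, base_alpha=0.5):
    """Sketch: rescale Lookahead's alpha by how well consecutive fast
    updates agree in direction (cosine similarity in weight space)."""
    # Update vectors between consecutive fast-weight snapshots.
    deltas = [b - a for a, b in zip(fast_weights[:-1], fast_weights[1:])]
    # Cosine similarity between each pair of consecutive updates.
    sims = [F.cosine_similarity(u, v, dim=0) for u, v in zip(deltas[:-1], deltas[1:])]
    mean_sim = torch.stack(sims).mean().clamp(min=0.0)  # ignore opposing directions
    # Straight trajectory (sim ~ 1) keeps the full alpha; a curve shrinks it.
    return base_alpha * mean_sim.item()

snapshots = [torch.randn(100) for _ in range(5)]
alpha = cosine_scaled_alpha(snapshots)
```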
In 2D this is what I have in mind.
I didn't consider overshooting, but rather selecting alpha automatically based on the variance of the fast updates:
I did a LARS-style version of Lookahead which uses the 'trust_ratio' between the norms of the fast and slow weights... and at 5 epochs you don't see such a noticeable change... but I haven't had the time to run it further or with an annealing schedule, which looks like it makes a lot of difference based on the results just published by @mgrankin. Quick question: do those numbers already include the change to use the Mish activation?
EDIT: A quick caveat... when I say you can overshoot, what I mean is that you update the fast weights into overshooting territory, while the slow weights are modified in such a way that you do not move into the overshooting range.
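Since that variant is only described in words here, this is a hedged sketch of what a trust-ratio-scaled Lookahead synchronization step could look like; the names, the clamp, and the cap on the ratio are my assumptions, not the actual code that was benchmarked:

```python
import torch

def lars_style_lookahead_sync(slow, fast, base_alpha=0.5):
    """Sketch: a Lookahead slow-weight update whose interpolation rate is
    scaled by a LARS-like trust ratio between the slow weights and the
    fast-minus-slow step, per parameter tensor."""
    for s, f in zip(slow, fast):
        step = f - s
        weight_norm = s.norm().clamp(0, 10)  # same clamp as Ralamb above
        step_norm = step.norm()
        if weight_norm == 0 or step_norm == 0:
            trust_ratio = 1.0
        else:
            trust_ratio = (weight_norm / step_norm).item()
        # Cap at 1 so the sync never moves past the fast weights (assumption).
        s.add_(step, alpha=base_alpha * min(trust_ratio, 1.0))

slow = [torch.randn(10, 10)]
fast = [slow[0] + 0.01 * torch.randn(10, 10)]
lars_style_lookahead_sync(slow, fast)
```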
Quick question: do those numbers already include the change to use the Mish activation?
No, I haven't looked into Mish yet.
@redknightlois Hi, thanks for the implementation. I have a question: how do I save the Ralamb optimizer's state_dict? There is no function for that, and there is no load_state_dict function either. Thanks.
No, this was a prototype that I knocked up in a few hours' time. Feel free to add those and I will update it.
Interesting thought, @redknightlois.
If I get your idea correctly, you want to rescale the interpolation/update of the slow weights in Lookahead based on a distance metric between the gradients of the intermediate fast steps (and you're suggesting cosine similarity as the distance metric)?
And you remove LARS from the actual base optimizer?
I can draft a few options for the math of this and implement it, but I'll probably need something to run the training on, as I believe I don't have the hardware to match the training conditions you guys are using!