-
-
Save redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 to your computer and use it in GitHub Desktop.
class Ralamb(Optimizer):
    """Adam-style optimizer with a rectified step size and a layer-wise trust ratio.

    For every parameter the optimizer keeps exponential moving averages of the
    gradient (``exp_avg``) and of the squared gradient (``exp_avg_sq``).  The
    step size is rectified from the approximated SMA length ``N_sma``: while
    ``N_sma < 5`` the second-moment denominator is skipped.  The resulting
    update is then rescaled by ``trust_ratio = weight_norm / radam_norm``
    where ``weight_norm`` is the parameter norm clamped to ``[0, 10]`` and
    ``radam_norm`` is the norm of the tentatively updated weights.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate of the group.
        betas: ``(beta1, beta2)`` decay rates of the first and second moment
            running averages.
        eps: value added to ``sqrt(exp_avg_sq)`` before dividing.
        weight_decay: when non-zero, weights are shrunk by
            ``weight_decay * lr * weight`` before the update is applied.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Ring of 10 cache slots, each ``[key, N_sma, step_size]``; the slot is
        # picked by ``step % 10`` (see ``_rectified_step_size``).
        self.buffer = [[None, None, None] for _ in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)
        # ``buffer`` is only a cache and may not be part of the restored state
        # (e.g. after unpickling or ``copy.deepcopy``); recreate it so that
        # ``step`` does not fail with an AttributeError.
        if 'buffer' not in self.__dict__:
            self.buffer = [[None, None, None] for _ in range(10)]

    def _rectified_step_size(self, step, lr, beta1, beta2):
        """Return ``(N_sma, step_size)`` for the given step and hyper-parameters.

        Results are cached in ``self.buffer``.  The cache key contains the
        learning rate and both betas in addition to the step count, so
        parameter groups with different hyper-parameters never pick up each
        other's step size.
        """
        key = (step, lr, beta1, beta2)
        buffered = self.buffer[int(step % 10)]
        if buffered[0] == key:
            return buffered[1], buffered[2]

        beta2_t = beta2 ** step
        N_sma_max = 2 / (1 - beta2) - 1
        N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
        bias_correction1 = 1 - beta1 ** step

        # more conservative since it's an approximated value
        if N_sma >= 5:
            rectification = math.sqrt(
                (1 - beta2_t)
                * (N_sma - 4) / (N_sma_max - 4)
                * (N_sma - 2) / N_sma
                * N_sma_max / (N_sma_max - 2)
            )
            step_size = lr * rectification / bias_correction1
        else:
            step_size = lr / bias_correction1

        buffered[0], buffered[1], buffered[2] = key, N_sma, step_size
        return N_sma, step_size

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is called once before the parameters are updated.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                # Reject sparse gradients before attempting any conversion.
                if p.grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                # All arithmetic is done on float32 copies; the result is
                # copied back into ``p.data`` at the end.
                grad = p.grad.data.float()
                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                N_sma, radam_step_size = self._rectified_step_size(
                    state['step'], group['lr'], beta1, beta2
                )

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # Tentative update, only used to measure its norm for the
                # trust ratio below.
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size)
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size)

                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

                # Kept in the state for inspection / logging.
                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                # Plain Python float so it is a valid ``alpha`` / ``value``.
                scaled_step = -radam_step_size * float(trust_ratio)
                if N_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=scaled_step)
                else:
                    p_data_fp32.add_(exp_avg, alpha=scaled_step)

                p.data.copy_(p_data_fp32)

        return loss
And you remove LARS from the actual base optimizer?
Yes, you take out LARS from the base optimizer.
If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)
Yes, I suggested cosine similarity because you can picture it in your head in 2 dimensions. As you mentioned when the 'curve' is an actual curve, the similarity of the vectors is bad, therefore you have to use the normal schedule...
In 2D this is what I have in mind.
I didn't consider overshooting, but selecting alpha automatically based on the variance of the fast updates:
I did a LARS-style version of Lookahead which uses the 'trust_ratio' between the norms of the fast and slow weights... and at 5 epochs you don't see a very noticeable change... but I haven't had the time to run it further or with an annealing schedule, which looks like it is making a lot of difference based on the results just published by @mgrankin. Quick question: do those numbers already include the change to use the Mish activation?
EDIT: Quick caveat... when I am saying you can overshoot, what I mean is that you update the fast weights to overshooting and the slow weights are modified in such a way that you do not move into overshooting range.
Quick question, those numbers already include the change to use Mish activation?
No, I haven't looked into Mish yet.
@redknightlois. Hi, thanks for the implementation. I have a question: how do I save the Ralamb optimizer state_dict? There is no function for that, and there is no load_state_dict function either. Thanks
No, this was a prototype that I knocked together in a few hours' time. Feel free to add those and I will update it.
I re-ran the notebooks with the latest Ralamb and RangerLars and updated the main page with the results.