Ralamb optimizer (RAdam + LARS trick)
import math

import torch
from torch.optim.optimizer import Optimizer


class Ralamb(Optimizer):
    """RAdam combined with the LARS layer-wise trust-ratio trick."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Cache for the rectification term, keyed by step % 10
        self.buffer = [[None, None, None] for ind in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficients
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, radam_step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    # Length of the approximated simple moving average (RAdam)
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        radam_step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        radam_step_size = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = radam_step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size)
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size)

                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                # LARS trick: scale the RAdam update by the layer-wise trust ratio
                if N_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-radam_step_size * trust_ratio)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-radam_step_size * trust_ratio)

                p.data.copy_(p_data_fp32)

        return loss
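For quick testing, here is a minimal, hypothetical usage sketch; the toy model, data, and hyperparameter values are placeholders, and only the `Ralamb(...)` constructor comes from the file above:

```python
import torch
import torch.nn as nn

# toy model and synthetic data, purely illustrative
model = nn.Linear(10, 2)
optimizer = Ralamb(model.parameters(), lr=1e-3, weight_decay=1e-2)
criterion = nn.CrossEntropyLoss()

for _ in range(100):
    inputs = torch.randn(32, 10)
    targets = torch.randint(0, 2, (32,))

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
```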
Quick question: do those numbers already include the change to use the Mish activation?
No, I haven't looked into Mish yet.
@redknightlois Hi, thanks for the implementation. I have a question: how do I save the Ralamb optimizer state_dict? There is no function for that, and there is no load_state_dict function either. Thanks.
No, this was a prototype that I knocked up in a few hours' time. Feel free to add those and I will update it.
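For what it's worth: since `Ralamb` subclasses `torch.optim.Optimizer`, the generic `state_dict()` and `load_state_dict()` methods are inherited from the base class, so checkpointing should already work without extra code (the internal `self.buffer` cache is not serialized, but it is only a cache and is rebuilt lazily). A minimal sketch, with an illustrative file name:

```python
import torch

# save the optimizer state alongside the model checkpoint
torch.save(optimizer.state_dict(), 'ralamb_state.pt')

# ... later: recreate the model and the Ralamb optimizer, then restore its state
optimizer.load_state_dict(torch.load('ralamb_state.pt'))
```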
Yes, you take out LARS from the base optimizer.
Yes, I suggested cosine similarity because you can picture it in your head in two dimensions. As you mentioned, when the 'curve' is an actual curve, the similarity of the vectors is bad, therefore you have to use the normal schedule...
In 2D this is what I have in mind.

I did a LARS-style version of Lookahead which uses the 'trust_ratio' between the norms of the fast and slow weights... and at 5 epochs you don't see such a noticeable change... but I haven't had the time to run it further OR with an annealing schedule, which looks like it is making a lot of difference based on the results just published by @mgrankin. Quick question: do those numbers already include the change to use the Mish activation?
EDIT: Quick caveat... when I say you can overshoot, what I mean is that you update the fast weights into overshooting range, while the slow weights are modified in such a way that you do not move into overshooting range.
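A hedged sketch of what that LARS-style Lookahead synchronization might look like, reconstructed from the description above (this is not the actual code being referenced): the slow weights are pulled toward the fast weights, with the interpolation step scaled by a trust ratio between the slow-weight norm and the norm of the fast/slow difference.

```python
import torch

def lookahead_sync(slow_params, fast_params, alpha=0.5):
    """Pull the slow weights toward the fast weights, LARS-style:
    scale the interpolation by the ratio of the slow-weight norm
    to the norm of the (fast - slow) difference. Hypothetical sketch."""
    with torch.no_grad():
        for slow, fast in zip(slow_params, fast_params):
            diff = fast - slow                 # direction toward the fast weights
            slow_norm = slow.norm()
            diff_norm = diff.norm()
            if slow_norm == 0 or diff_norm == 0:
                trust_ratio = 1.0
            else:
                trust_ratio = (slow_norm / diff_norm).item()
            # slow <- slow + alpha * trust_ratio * (fast - slow)
            slow.add_(diff, alpha=alpha * trust_ratio)
            # restart the fast weights from the updated slow weights
            fast.copy_(slow)
```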