-
-
Save redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 to your computer and use it in GitHub Desktop.
class Ralamb(Optimizer):
    """RAdam with a LAMB/LARS-style layer-wise trust ratio.

    For every parameter the rectified Adam (RAdam) step is computed first.
    The step actually applied is that RAdam step multiplied by
    ``trust_ratio = weight_norm / radam_norm``, where:

    * ``weight_norm`` is the norm of the (weight-decayed) parameter, clamped
      to ``[0, 10]``;
    * ``radam_norm`` is the norm of the parameter after a tentative RAdam step.

    If either norm is zero the ratio is 1.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq``,
    ``weight_norm``, ``adam_norm``, ``trust_ratio``.

    Args:
        params: iterable of parameters, or of dicts defining parameter groups.
        lr: learning rate.
        betas: coefficients for the running averages of the gradient and of
            its square.
        eps: term added to the second-moment denominator for stability.
        weight_decay: if non-zero, each parameter is updated as
            ``p += -weight_decay * lr * p`` before the RAdam step is computed.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Legacy step-size cache. It was shared across parameter groups, so a
        # group with a different lr/betas reused another group's step size.
        # It is no longer consulted; the attribute is kept for compatibility.
        self.buffer = [[None, None, None] for _ in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    @staticmethod
    def _radam_step_size(group, step):
        """Return ``(N_sma, step_size)`` for parameter *group* at 1-based *step*.

        ``N_sma`` is the approximated length of the simple moving average used
        by RAdam. The step size already includes the group's lr and the
        first-moment bias correction. It includes the variance rectification
        term only when ``N_sma >= 5``.
        """
        beta1, beta2 = group['betas']
        beta2_t = beta2 ** step
        n_sma_max = 2 / (1 - beta2) - 1
        n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
        bias_correction1 = 1 - beta1 ** step
        # Threshold of 5 is deliberately conservative, since N_sma is an
        # approximated value.
        if n_sma >= 5:
            rectification = math.sqrt(
                (1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4)
                * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
            )
            step_size = group['lr'] * rectification / bias_correction1
        else:
            step_size = group['lr'] / bias_correction1
        return n_sma, step_size

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by *closure*, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                # fp32 working copy. For fp32 params `.float()` returns the
                # same tensor, so in-place ops below act on p.data directly.
                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    # Keep restored moment buffers in the working dtype.
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # m_t: running average of the gradient.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t: running average of the squared gradient.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                # Computed per group, so each group's own lr/betas are honoured.
                n_sma, radam_step_size = self._radam_step_size(group, state['step'])

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # Tentative RAdam step; only its norm is used.
                radam_step = p_data_fp32.clone()
                if n_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size)
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size)

                radam_norm = radam_step.pow(2).sum().sqrt()
                # Measured on the fp32 (decayed) weights. This matches the
                # previous fp32 behaviour and avoids fp16 overflow in the
                # sum of squares.
                weight_norm = p_data_fp32.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                scaled_step_size = -radam_step_size * float(trust_ratio)
                if n_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=scaled_step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=scaled_step_size)

                # Write back (a no-op self-copy for fp32 params).
                p.data.copy_(p_data_fp32)

        return loss
# Note: Here are two choices for scaling function \phi(z)
# minmax: \phi(z) = min(max(z, \gamma_l), \gamma_u)
# identity: \phi(z) = z
# The authors do not mention the values of \gamma_l and \gamma_u.
# UPDATE: after asking authors, they provide me the code below.
# ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
# math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
The authors might not use that clipping. The code they provided is equivalent to the following code:
if weight_norm == 0 or radam_norm == 0:
trust_ratio = 1
else:
trust_ratio = weight_norm / radam_norm
Update:
https://github.com/tensorflow/tensorflow/blob/66bb198acad21260038805e02960b791cb467177/tensorflow/contrib/opt/python/training/lars_optimizer.py#L108-L114
TensorFlow's LARSOptimizer has a similar code:
trust_ratio = array_ops.where(
math_ops.greater(w_norm, 0),
array_ops.where(
math_ops.greater(g_norm, 0),
(self._eeta * w_norm /
(g_norm + self._weight_decay * w_norm + self._epsilon)), 1.0),
1.0)
Update2:
https://github.com/borisgin/nvcaffe/blob/8896d3303cfdfb575923d6bb108a99ebda728855/src/caffe/solvers/sgd_solver.cpp#L321-L326
Caffe's LARC uses the same method:
float rate = 1.F;
if (w_norm > 0.F && wgrad_norm > 0.F) {
//float weight_decay = this->param_.weight_decay();
//rate = gw_ratio * w_norm / (wgrad_norm + weight_decay * w_norm);
rate = gw_ratio * w_norm / wgrad_norm ;
}
I suspect the authors didn't provide a code for that clipping.
This is just anecdotal, but I've used the updated version with my own dataset, and no clipping is better than clipping to 10.
Thanks for releasing this code Federico, and to all contributors for making it better!
@mgrankin Have you seen in the 80 annealing schedule such a huge back and forth in loss?
@redknightlois Haven't run 80 so far, but it's worth digging those swings.
We definitely did something wrong here... Version 3 is broken; I am not entirely sure it converges at all. Back to the drawing board. @r1ckya Any idea what might be wrong?
@redknightlois yes same here
I went back and reimplemented both papers (LARS optional, and the above-mentioned clipping being optional as well) and it seems to be working for now: https://github.com/frgfm/Holocron/blob/master/holocron/optim/radam.py
@frgfm you do not use bias correction on the step size. I thought about removing that too, but didn't have time to test it yet (in the paper they ditched it too, at least the last time I checked https://arxiv.org/abs/1904.00962v3 ); maybe that affects something in a weird way.
I am testing it now and got better results; the problem was that we were doing it wrong. Really wrong... will update in a few minutes.
python train.py --run 20 --woof 0 --size 128 --bs 64 --mixup 0 --sa 0 --epoch 5 --lr 1e-2 --gpu 0 --opt ralamb
lr: 0.01; eff_lr: 0.01; size: 128; alpha: 0.99; mom: 0.9; eps: 1e-06
\.fastai\data\imagenette-160
epoch train_loss valid_loss accuracy top_k_accuracy time
0 1.796517 1.931205 0.484000 0.890000 02:26
1 1.509954 1.727263 0.536000 0.912000 02:11
2 1.276914 1.195526 0.734000 0.958000 02:11
3 1.128250 1.019652 0.792000 0.978000 02:13
4 1.009607 0.911938 0.834000 0.984000 02:11
@r1ckya If you are referring to the bias-correction of first moment, I noticed I had an unpushed commit!
https://github.com/frgfm/Holocron/blob/master/holocron/optim/radam.py#L90-L91
For the rest, I actually stuck with the idea of the initial LARS paper https://arxiv.org/pdf/1708.03888.pdf. I'll check the difference in performance with the paper you mention
Thanks @redknightlois!
@frgfm I cannot make yours to converge on ImageWoof.
@redknightlois Sorry about that, the unpushed commit had an issue in it (only correcting bias for first but not second moment)
I updated it, but even with the above, I'm not training on ImageWoof but it's performing quite well. Not using weight_decay in my case, and had to scale up the learning rate but it's definitely converging!
But your first implementation is still working for me.
It's just when I put the changes you both mentioned earlier, I cannot make it work, for unclear reasons
That's because it is broken... I just updated to the proper one. I updated to your newer version but no luck either. If you are playing with it with lookahead, careful because it can do well even if base optimizer is crap based on what I have seen.
These are a few of the results of the current version
lr: 0.001; eff_lr: 0.001; size: 128; alpha: 0.99; mom: 0.9; eps: 1e-06
epoch train_loss valid_loss accuracy top_k_accuracy time
0 2.117527 2.215535 0.244000 0.736000 02:48
1 1.949242 2.063208 0.320000 0.830000 02:12
2 1.745239 1.941830 0.356000 0.874000 02:12
3 1.559608 1.542116 0.524000 0.938000 02:08
4 1.452923 1.492080 0.562000 0.940000 02:10
epoch train_loss valid_loss accuracy top_k_accuracy time
0 2.112155 2.267176 0.240000 0.724000 02:13
1 1.976614 2.071130 0.300000 0.778000 02:20
2 1.753372 1.787471 0.414000 0.880000 02:14
3 1.570473 1.574903 0.518000 0.930000 02:16
4 1.450896 1.522807 0.524000 0.940000 02:11
Great, I'll update the notebooks in repo soon.
Thanks a lot redknightlois. I've tried v4 of ralamb and it works much better! I got a new SOTA on a problem I'm working on, and it's much smoother.
@oguiza did you try it with Lookahead or just as it is?
I have also tested it with Lookahead, but my results are a bit worse. Please bear in mind this is with my own dataset.
Same here; it is as if LARS is somehow interacting with Lookahead. I have a few ideas that might be worth exploring on that front.
After spending hours locating the issue on my implementation, I found out that I was wrongly accumulating bias correction of momentum. I tend to forget sometimes the mutability of some Python objects...
It was also confirmed by @r1ckya
Anyway here is the fix. So far, my tests seem to point out that it's holding its own compared to the first revision from here!
Two differences between the last revision of @redknightlois & my implementations:
- Minor difference: on lines 72-73 it is subtle, but since you add epsilon to the second-moment term before multiplying it by its bias correction (in `radam_step_size`), you obtain:
`(sqrt(exp_avg_sq) + group['eps']) / sqrt(bias_correction2)`
I stuck to the paper and used the following instead:
sqrt(exp_avg_sq / bias_correction2) + group['eps']
- Major difference: according to the paper, the denominator of `local_lr` should be the norm of:
`expected_update = adaptive_momentum + group['weight_decay'] * p.data`
where the adaptive momentum equals `r_t * exp_avg_hat / (sqrt(exp_avg_sq_hat) + eps)` if `sma > 4`, and equals `exp_avg_hat` otherwise.
But according to line 70 and the following ones, you actually take the norm of:
p.data - group['lr'] * expected_update
And in your case, when sma > 4, your adaptive momentum is actually r_t * exp_avg_hat / (sqrt(exp_avg_sq_hat) + eps / sqrt(bias_correction2))
Interestingly, your revision is performing quite well, so I guess you somehow did a finding of your own! It rescales the update by the norm of the expected updated params (without LARS):
local_lr = phi(norm(p)) / norm(expected_p)
where expected_p = p - group['lr'] * expected_update
instead of rescaling the update by its own norm:
local_lr = phi(norm(p)) / norm(expected_update)
I agree with @frgfm: v3 was very similar to other LAMB or Relamb implementations I've seen, and it looked more or less like what the papers describe, but v4 is very different from that (the major difference pointed out by @frgfm).
I am testing it now and got better results, the problem was that we were doing it wrong. Real wrong... will update in a few minutes
I'm wondering what, in your opinion, was wrong with v3, and how you came up with this different trust_ratio calculation rule.
I am still waiting to get some results on the 20 epoch run, but the 5 epoch runs on Ralars, Ralamb (v4), Ralamb (v1), RAdam and Lookahead show the following:
- Lookahead and RAdam are on top.
- Ralamb v4
- Ralars and Ralamb v1 are comparable
This is what we know so far:
- V1 was flawed (yet somehow got good results), and AFAIK (@mgrankin should comment on that) it is the version behind the results on the main page.
- V2 and V3, they are even worse... These guys do not converge at all when running on isolation (which makes sense because we are not calculating the proper complete step before taking the actual step forward).
- V4: those differences are completely unintended; they are bugs. However, if what @frgfm is saying is right, we should probably dig there, because there may be something we don't know yet about the behavior.
- I was not able to locally reproduce the results that @mgrankin repo has on the readme. (weird I know)
I have been working on a variant of Lookahead because it is as if LARS style optimizers do not play along with Lookahead (reason why Ranger beat all the others). But I am stuck, I need someone with a stronger math background to figure out the formula.
The idea is to incorporate the basic idea of LARS directly into Lookahead. Put it this way: you probably want to overshoot if the gradient has generally the same direction in all intermediate steps. Therefore, if you take the intermediate points $t_0, \dots, t_{k/2}, \dots, t_k$, the vectors $t_0 \to t_{k/2}$ and $t_0 \to t_k$ should have a cosine similarity of 1... In that case you are probably looking at a deep-dive kind of landscape. However, if that cosine similarity is going to 0, we are probably in exploration mode, and the buddy should stay a bit away in case the explorer trips into a hole (local minimum).
If the cosine similarity is 1, there is no harm in overshooting (trust ratio $> 1$), even though the buddy is going to be updated proportionally to $\alpha$ in the fast direction. While Lookahead today updates the slow weights but doesn't change the fast gradient, I am saying that, from the point of view of the buddy, we could push the fast gradient into overshooting mode, because the buddy believes it is safe to do so. And if it fails, you have just wasted $k$ batches.
What do you think?
Interesting thought @redknightlois
If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)
And you remove LARS from the actual base optimizer?
I can draft a few options on the math of this and implement it, but I'll probably need something to run the training as I believe I don't have the hardware to meet the same training conditions as you guys!
For the gradient, I actually can picture cases where cosine similarity of gradients wouldn't help (if that's what you meant). Since it's a wrapper, we don't control the way the base optimizer performs the fast steps. So imagine that we have an optimizer giving a lot of importance to momentum or other inertia factor and that, the gradient of fast weights at each step is very similar (cosine dist).
You may have the projection into a 2D scatter plot of fast weights that is actually a curve rather than a line (optimizer magic). In this case, we would overshoot further than the last fast weights but in a direction that would not fit the curve. I would have to test the hypothesis (perhaps it doesn't occur that often, but my intuition is that optimizers, that have updates where gradient is not the only main player, might not work well with this approach)
While implementing Lookahead, I actually had a somehow similar idea @redknightlois, but more of a model forward/backward efficiency perspective.
I didn't consider overshooting, but selecting alpha automatically based on the variance of the fast updates:
Let $s$ be the slow weights, and $f_1, \dots, f_k$ be the fast weights.
- I evaluate the variance of $[f_2 - f_1, \dots, f_k - f_{k-1}]$, which would characterize the consistency of the fast updates' direction.
- I would normalize it / squeeze it into $[0, 1]$ and use $1 - \text{squeezed\_variance}$ as a synchronization rate ($\alpha$).

The issue I had is that, memory-wise, I would have to store $k-1$ fast weights in memory to perform this. So, hopefully, less computation would be required, but memory usage would be higher (growing linearly with the synchronization period $k$). I'll check if I can avoid that memory overhead and try to implement this.
I re-run notebooks with latest Ralamb and RangerLars and updated the main page with the results.
And you remove LARS from the actual base optimizer?
Yes, you take out LARS from the base optimizer.
If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)
Yes, I suggested cosine similarity because you can picture it in your head in 2 dimensions. As you mentioned when the 'curve' is an actual curve, the similarity of the vectors is bad, therefore you have to use the normal schedule...
In 2D this is what I have in mind.
I didn't consider overshooting, but selecting alpha automatically based on the variance of the fast updates:
I did a LARS-style version of Lookahead which uses the 'trust_ratio' between the norms of the fast and slow weights... At 5 epochs you don't see a very noticeable change, but I haven't had the time to run it further OR with the annealing schedule, which seems to make a lot of difference based on the results just published by @mgrankin. Quick question: do those numbers already include the change to use the Mish activation?
EDIT: Quick caveat... when I am saying you can overshoot, what I mean is that you update the fast weights to overshooting and the slow weights are modified in such a way that you do not move into overshooting range.
Quick question, those numbers already include the change to use Mish activation?
No, I haven't looked into Mish yet.
@redknightlois. Hi, Thanks for the implementation. I have a question, how to save the Ralamb optimizer state_dict. There is no function for that. There is no load_state_dict function as well. Thanks
No, this was a prototype that I knocked up in a few hours time. Feel free to add those and I will update it.
@mgrankin Have you seen in the 80 annealing schedule such a huge back and forth in loss?