This is the unfused version of StableAdamW. It is slower than the fused version (coming). Beta version.
import torch


# This is the unfused version of StableAdamW. It is slower than the fused version (coming).
class StableAdamWUnfused(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=0.002,
        weight_decay=0.2,
        betas=(0.9, 0.99),
        eps=1e-6,
        clip_thresh=1.,
        precision='amp_bfloat16',
        custom_scalar=65536,
    ):
        beta1, beta2 = betas[0], betas[1]
        defaults = dict(lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2)
        super(StableAdamWUnfused, self).__init__(params, defaults)

        self.eps = eps
        self.d = clip_thresh

        # Set precision to "custom_fp16" if you want to use a fixed loss scalar, custom_scalar,
        # which is divided out in the update step.
        # If you do this, call (custom_scalar * loss).backward() instead of loss.backward().
        self.precision = precision
        self.custom_scaler = custom_scalar

        for group in self.param_groups:
            group['step'] = 1.

        print('Using StableAdamWUnfused-v1')

    def __setstate__(self, state):
        super(StableAdamWUnfused, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            beta1 = group['beta1']
            beta2 = group['beta2']
            step = group['step']

            for p in group['params']:
                if p.grad is None:
                    continue
                theta = p.data
                param_state = self.state[p]

                if self.precision == 'custom_fp16':
                    # Divide the fixed loss scalar back out; skip the update if the gradient overflowed.
                    g = p.grad.data / self.custom_scaler
                    if torch.any(torch.isnan(g) | torch.isinf(g)):
                        continue
                else:
                    g = p.grad.data

                # Lazily initialize the first- and second-moment estimates.
                if 'exp_avg' not in param_state:
                    v = param_state['exp_avg'] = torch.zeros_like(theta)
                    u = param_state['exp_avg_sq'] = torch.zeros_like(theta)
                else:
                    v = param_state['exp_avg']
                    u = param_state['exp_avg_sq']

                # Bias-corrected decay rates: the usual Adam bias correction folded into the EMA decay.
                beta1hat = beta1 * (1 - beta1 ** (step - 1)) / (1 - beta1 ** step)
                beta2hat = beta2 * (1 - beta2 ** (step - 1)) / (1 - beta2 ** step)

                v = v.mul_(beta1hat).add_(g, alpha=1.0 - beta1hat)
                u = u.mul_(beta2hat).addcmul_(g, g, value=1.0 - beta2hat)

                denominator = u.sqrt().add_(self.eps)

                # StableAdamW = AdamW + update clipping (https://arxiv.org/abs/1804.04235) applied tensor-wise.
                # The per-tensor RMS of g^2 / max(u, eps^2) shrinks the learning rate when it exceeds clip_thresh.
                rms = torch.div(
                    g.pow(2), torch.maximum(u, (self.eps ** 2) * torch.ones_like(u))
                ).mean().sqrt().item()
                new_lr = lr * (1. / max(1., rms / self.d))

                # Decoupled weight decay followed by the Adam update.
                theta = theta.mul_(1.0 - new_lr * weight_decay).addcdiv_(v, denominator, value=-new_lr)

                # Save the updated moment estimates.
                param_state['exp_avg'] = v
                param_state['exp_avg_sq'] = u

            group['step'] = step + 1

        return loss
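
For reference, a minimal usage sketch (not part of the original gist): the model, data, and loss here are hypothetical stand-ins, and the commented-out lines show how the "custom_fp16" fixed loss scaling described above would be wired in.

import torch
import torch.nn as nn

# Hypothetical toy model and data; substitute your own.
model = nn.Linear(512, 512)
optimizer = StableAdamWUnfused(model.parameters(), lr=2e-4, weight_decay=0.1)

for _ in range(10):
    x = torch.randn(8, 512)
    optimizer.zero_grad()
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()

# With fixed loss scaling instead, construct the optimizer with precision='custom_fp16'
# and scale the loss before backward, e.g.:
#   optimizer = StableAdamWUnfused(model.parameters(), precision='custom_fp16')
#   (optimizer.custom_scaler * loss).backward()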