Implements a Cautious AdamW optimizer by subclassing AdamW (https://github.com/kyleliang919/C-Optim).
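For orientation, the cautious rule zeroes out every component of the Adam update whose sign disagrees with the current gradient, and (optionally) rescales the survivors so the average update magnitude is preserved. A minimal illustrative sketch of just that rule on a toy tensor (not part of the gist; the values are arbitrary):

import torch

# toy update direction and gradient (arbitrary illustrative values)
update = torch.tensor([0.5, -0.2, 0.1])
grad = torch.tensor([1.0, 0.3, -0.4])

# keep only components where update and gradient agree in sign
mask = (update * grad > 0).to(update.dtype)   # -> tensor([1., 0., 0.])

# optional rescaling by the inverse of the surviving fraction (mask_scale=True)
mask = mask / mask.mean().clamp_min(1e-8)     # -> tensor([3., 0., 0.])

masked_update = update * mask                 # -> tensor([1.5, 0., 0.])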
# Cautious AdamW implemented by subclassing torch.optim.AdamW.
import math

import torch
from torch.optim import AdamW


class CautiousAdamW(AdamW):
    """
    Implements the Cautious AdamW optimizer by subclassing AdamW.

    All hyperparameters are identical to AdamW; the only change is that the
    cautious mask is applied to each update before the parameters are stepped.

    Arguments match torch.optim.AdamW exactly.

    Additional args:
        mask_scale (bool): If True, rescale the surviving update components by
            the inverse of the fraction of elements kept by the mask, so the
            average update magnitude is preserved.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        *,
        maximize=False,
        foreach=None,
        capturable=False,
        differentiable=False,
        mask_scale=True,
    ):
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
        )
        # Record mask_scale in every param group
        for group in self.param_groups:
            group["mask_scale"] = mask_scale

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Modified from AdamW to include cautious masking.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group["betas"]
            mask_scale = group["mask_scale"]
            eps = group["eps"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            maximize = group["maximize"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError(
                        "CautiousAdamW does not support sparse gradients"
                    )
                params_with_grad.append(p)
                grad = p.grad
                if maximize:
                    grad = -grad
                grads.append(grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if group["amsgrad"]:
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                if group["amsgrad"]:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                state["step"] += 1
                state_steps.append(state["step"])

            # Standard AdamW update computations, element-wise per parameter
            for i, param in enumerate(params_with_grad):
                grad = grads[i]
                exp_avg = exp_avgs[i]
                exp_avg_sq = exp_avg_sqs[i]
                step = state_steps[i]

                # Decoupled weight decay (AdamW-style: shrink the parameter directly)
                if weight_decay != 0:
                    param.mul_(1 - lr * weight_decay)

                # Update biased first and second moment estimates
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias corrections
                bias_correction1 = 1 - beta1**step
                bias_correction2 = 1 - beta2**step

                # Compute step size
                step_size = lr / bias_correction1

                # Compute denominator (AMSGrad keeps the running max of the second moment)
                if group["amsgrad"]:
                    max_exp_avg_sq = max_exp_avg_sqs[i]
                    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Compute the raw (unmasked) Adam update direction
                update = exp_avg / denom

                # Cautious masking: keep only components whose sign agrees with the gradient
                mask = (update * grad > 0).to(dtype=update.dtype)
                if mask_scale:
                    # Use a small epsilon to prevent division by zero when the mask is all zeros
                    mask_epsilon = 1e-8
                    mask_mean = mask.mean().add_(mask_epsilon)
                    mask = mask / mask_mean

                # Apply mask to update
                update = update * mask

                # Apply the update
                param.add_(update, alpha=-step_size)

        return loss
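A minimal, hedged usage sketch for the subclass above; the toy model, data, and step count are placeholders and assume the file defining CautiousAdamW has been imported:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = CautiousAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2, mask_scale=True)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()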

# Standalone implementation of Cautious AdamW that does not subclass AdamW,
# with single-tensor and multi-tensor update paths.
import math
from typing import List, Optional

import torch
from torch.optim import Optimizer


def _single_tensor_cautious_adamw(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    exp_avg_sqs: List[torch.Tensor],
    max_exp_avg_sqs: List[torch.Tensor],
    state_steps: List[int],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    mask_scale: bool,
    amsgrad: bool,
    capturable: bool,
    differentiable: bool,
):
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Decoupled weight decay (AdamW-style: shrink the parameter directly)
        if weight_decay != 0:
            param.mul_(1 - lr * weight_decay)

        # Update moving averages
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable:
            # Keep the step on-device so the update can be graph-captured
            step = torch.scalar_tensor(step, dtype=param.dtype, device=param.device)
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1
        update = exp_avg / denom

        # Cautious mask: 1 where the update and gradient agree in sign, else 0
        mask = torch.sign(update * grad).clamp(min=0)
        if mask_scale:
            mask_mean = mask.mean().add_(eps)
            mask = mask / mask_mean

        # Apply the masked update
        param.add_(mask * update, alpha=-step_size)


def _multi_tensor_cautious_adamw(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    exp_avg_sqs: List[torch.Tensor],
    max_exp_avg_sqs: List[torch.Tensor],
    state_steps: List[int],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    mask_scale: bool,
    amsgrad: bool,
    capturable: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    if maximize:
        grads = [-grad for grad in grads]

    # Decoupled weight decay (AdamW-style), applied in place
    if weight_decay != 0:
        torch._foreach_mul_(params, 1 - lr * weight_decay)

    # Update moving averages in place so the optimizer state actually advances
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
    torch._foreach_mul_(exp_avg_sqs, beta2)
    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)

    bias_correction1 = [1 - beta1 ** step for step in state_steps]
    bias_correction2 = [1 - beta2 ** step for step in state_steps]

    if amsgrad:
        # Track the running maximum of the second moment in place
        torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
        denom = [(m.sqrt() / math.sqrt(bc2)).add_(eps) for m, bc2 in zip(max_exp_avg_sqs, bias_correction2)]
    else:
        denom = [(v.sqrt() / math.sqrt(bc2)).add_(eps) for v, bc2 in zip(exp_avg_sqs, bias_correction2)]

    step_sizes = [lr / bc1 for bc1 in bias_correction1]
    updates = [exp_avg / d for exp_avg, d in zip(exp_avgs, denom)]

    # Cautious masks: 1 where the update and gradient agree in sign, else 0
    masks = [torch.sign(update * grad).clamp(min=0) for update, grad in zip(updates, grads)]
    if mask_scale:
        mask_means = [mask.mean().add_(eps) for mask in masks]
        masks = [mask / mask_mean for mask, mask_mean in zip(masks, mask_means)]

    # Apply the masked updates in place
    for param, mask, update, step_size in zip(params, masks, updates, step_sizes):
        param.add_(mask * update, alpha=-step_size)


class OptimizedCautiousAdamW(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False, *, maximize=False,
                 foreach: Optional[bool] = None, capturable: bool = False,
                 differentiable: bool = False, mask_scale: bool = True):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize, foreach=foreach,
                        capturable=capturable, differentiable=differentiable,
                        mask_scale=mask_scale)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']
            lr = group['lr']
            weight_decay = group['weight_decay']
            eps = group['eps']
            maximize = group['maximize']
            amsgrad = group['amsgrad']
            mask_scale = group['mask_scale']
            capturable = group['capturable']
            differentiable = group['differentiable']

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('OptimizedCautiousAdamW does not support sparse gradients')
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                state['step'] += 1
                state_steps.append(state['step'])

            # Choose the appropriate update path: the foreach flag selects the
            # multi-tensor implementation, otherwise fall back to single-tensor.
            if group['foreach']:
                func = _multi_tensor_cautious_adamw
            else:
                func = _single_tensor_cautious_adamw

            func(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                weight_decay=weight_decay,
                eps=eps,
                maximize=maximize,
                mask_scale=mask_scale,
                amsgrad=amsgrad,
                capturable=capturable,
                differentiable=differentiable,
            )

        return loss
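A similar hedged sketch for the standalone class; foreach=True exercises the multi-tensor path, and everything else here (model shape, learning rate, iteration count) is an arbitrary placeholder:

import torch
import torch.nn.functional as F

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
opt = OptimizedCautiousAdamW(model.parameters(), lr=3e-4, foreach=True, mask_scale=True)

x, y = torch.randn(16, 4), torch.randn(16, 1)
for _ in range(10):
    opt.zero_grad()
    F.mse_loss(model(x), y).backward()
    opt.step()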