@pszemraj
Last active November 27, 2024 03:53
Implements a Cautious AdamW optimizer by subclassing AdamW (reference: https://github.com/kyleliang919/C-Optim).
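
The cautious rule implemented below zeroes each coordinate of the Adam update whose sign disagrees with the current gradient; with mask_scale enabled, the surviving coordinates are rescaled by the inverse of the mask's mean so the overall step magnitude is roughly preserved. A minimal sketch of the rule on toy tensors (the names here are illustrative only, not part of the gist):

import torch

update = torch.tensor([0.5, -0.2, 0.1])  # Adam-style update direction
grad = torch.tensor([1.0, 0.3, -0.4])    # current gradient

mask = (update * grad > 0).float()       # keep only sign-agreeing coordinates -> [1., 0., 0.]
masked_update = update * mask / mask.mean().clamp_min(1e-8)  # optional rescale (mask_scale=True)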
import math

import torch
from torch.optim.adamw import AdamW


class CautiousAdamW(AdamW):
    """
    Implements Cautious AdamW optimizer by subclassing AdamW.
    All hyperparameters remain identical to AdamW.
    The only change is applying the cautious mask to updates.
    Arguments match torch.optim.AdamW exactly.

    Additional args:
        mask_scale (bool): If True, rescale the masked update by the inverse of the
            fraction of entries kept, so the overall step magnitude is roughly preserved.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        *,
        maximize=False,
        foreach=None,
        capturable=False,
        differentiable=False,
        mask_scale=True,
    ):
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
        )
        # Add mask_scale to param_groups
        for group in self.param_groups:
            group["mask_scale"] = mask_scale

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        Modified from AdamW to include cautious masking.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group["betas"]
            mask_scale = group["mask_scale"]
            eps = group["eps"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            maximize = group["maximize"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError(
                        "CautiousAdamW does not support sparse gradients"
                    )
                params_with_grad.append(p)
                grad = p.grad
                if maximize:
                    grad = -grad
                grads.append(grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if group["amsgrad"]:
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )
                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                if group["amsgrad"]:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
                state["step"] += 1
                state_steps.append(state["step"])

            # Standard AdamW update computations
            for i, param in enumerate(params_with_grad):
                grad = grads[i]
                exp_avg = exp_avgs[i]
                exp_avg_sq = exp_avg_sqs[i]
                step = state_steps[i]

                # Apply weight decay directly to the gradient
                if weight_decay != 0:
                    grad = grad.add(param, alpha=weight_decay)

                # Update biased first and second moment estimates
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias corrections
                bias_correction1 = 1 - beta1**step
                bias_correction2 = 1 - beta2**step

                # Compute step size
                step_size = lr / bias_correction1

                # Compute denominator (use the running max of the second moment for amsgrad)
                if group["amsgrad"]:
                    max_exp_avg_sq = max_exp_avg_sqs[i]
                    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Compute update
                update = exp_avg / denom

                # Cautious masking
                mask = (update * grad > 0).to(dtype=update.dtype)
                if mask_scale:
                    # Use a small epsilon to prevent division by zero
                    mask_epsilon = 1e-8
                    mask_mean = mask.mean().add_(mask_epsilon)
                    mask = mask / mask_mean

                # Apply mask to update
                update = update * mask

                # Apply the update
                param.add_(update, alpha=-step_size)

        return loss
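
A minimal usage sketch for the subclass version; the optimizer is a drop-in replacement for torch.optim.AdamW. The model, batch, and loss used here are placeholders, not part of the gist:

import torch
from torch import nn

model = nn.Linear(10, 1)  # placeholder model
optimizer = CautiousAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2, mask_scale=True)

x, y = torch.randn(8, 10), torch.randn(8, 1)  # placeholder batch
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()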

import math
from typing import List, Optional

import torch
from torch.optim import Optimizer


def _single_tensor_cautious_adamw(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    exp_avg_sqs: List[torch.Tensor],
    max_exp_avg_sqs: List[torch.Tensor],
    state_steps: List[int],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    mask_scale: bool,
    amsgrad: bool,
    capturable: bool,
    differentiable: bool,
):
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Update moving averages
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable:
            step = torch.scalar_tensor(step, dtype=param.dtype, device=param.device)
            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step
        else:
            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1
        update = exp_avg / denom

        # Compute mask
        mask = torch.sign(update * grad).clamp(min=0)
        if mask_scale:
            mask_mean = mask.mean().add_(eps)
            mask = mask / mask_mean

        # Apply update
        param.add_(mask * update, alpha=-step_size)
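

# Note: torch.sign(update * grad).clamp(min=0) above yields the same 0/1 mask as the
# (update * grad > 0) comparison used in the CautiousAdamW subclass: negative products
# map to -1 and are clamped to 0, zero products stay 0, positive products become 1.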


def _multi_tensor_cautious_adamw(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    exp_avg_sqs: List[torch.Tensor],
    max_exp_avg_sqs: List[torch.Tensor],
    state_steps: List[int],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    mask_scale: bool,
    amsgrad: bool,
    capturable: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    if maximize:
        grads = [-grad for grad in grads]

    # Weight decay (applied to the gradient, as in the single-tensor path)
    if weight_decay != 0:
        grads = [grad.add(param, alpha=weight_decay) for grad, param in zip(grads, params)]

    # Update moving averages in place so the optimizer state is actually mutated
    for exp_avg, exp_avg_sq, grad in zip(exp_avgs, exp_avg_sqs, grads):
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction1 = [1 - beta1 ** step for step in state_steps]
    bias_correction2 = [1 - beta2 ** step for step in state_steps]

    if amsgrad:
        # Maintain the running maximum of the second moment in place
        for max_exp_avg_sq, exp_avg_sq in zip(max_exp_avg_sqs, exp_avg_sqs):
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
        denom = [
            (max_exp_avg_sq.sqrt() / math.sqrt(bc2)).add_(eps)
            for max_exp_avg_sq, bc2 in zip(max_exp_avg_sqs, bias_correction2)
        ]
    else:
        denom = [
            (exp_avg_sq.sqrt() / math.sqrt(bc2)).add_(eps)
            for exp_avg_sq, bc2 in zip(exp_avg_sqs, bias_correction2)
        ]

    step_sizes = [lr / bc1 for bc1 in bias_correction1]
    updates = [exp_avg / d for exp_avg, d in zip(exp_avgs, denom)]

    # Compute masks
    masks = [torch.sign(update * grad).clamp(min=0) for update, grad in zip(updates, grads)]
    if mask_scale:
        mask_means = [mask.mean().add_(eps) for mask in masks]
        masks = [mask / mask_mean for mask, mask_mean in zip(masks, mask_means)]

    # Apply updates in place
    for param, mask, update, step_size in zip(params, masks, updates, step_sizes):
        param.add_(mask * update, alpha=-step_size)


class OptimizedCautiousAdamW(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False, *, maximize=False,
                 foreach: Optional[bool] = None, capturable: bool = False,
                 differentiable: bool = False, mask_scale: bool = True):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize, foreach=foreach,
                        capturable=capturable, differentiable=differentiable,
                        mask_scale=mask_scale)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']
            lr = group['lr']
            weight_decay = group['weight_decay']
            eps = group['eps']
            maximize = group['maximize']
            amsgrad = group['amsgrad']
            mask_scale = group['mask_scale']
            capturable = group['capturable']
            differentiable = group['differentiable']

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('CautiousAdamW does not support sparse gradients')
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
                state['step'] += 1
                state_steps.append(state['step'])

            # Only the single-tensor update path is dispatched here; the multi-tensor
            # variant defined above is provided but not selected.
            func = _single_tensor_cautious_adamw
            func(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                weight_decay=weight_decay,
                eps=eps,
                maximize=maximize,
                mask_scale=mask_scale,
                amsgrad=amsgrad,
                capturable=capturable,
                differentiable=differentiable,
            )

        return loss
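
A usage sketch for OptimizedCautiousAdamW, here with a closure; the model and batch are placeholders, not part of the gist:

import torch
from torch import nn

model = nn.Linear(10, 1)  # placeholder model
optimizer = OptimizedCautiousAdamW(model.parameters(), lr=1e-3, mask_scale=True)
x, y = torch.randn(8, 10), torch.randn(8, 1)  # placeholder batch

def closure():
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)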