Implementation of the syre weight decay algorithm from "Remove Symmetries to Control Model Expressivity and Improve Optimization" (https://arxiv.org/abs/2408.15495).
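In brief, syre weight decay shrinks each weight toward a fixed random reference point theta_0 instead of toward zero, and randomly perturbs the decay strength element-wise. The Triton kernel below computes theta <- theta - gamma * D * (theta - theta_0) with theta_0 ~ N(0, std^2) and D ~ U(1 - eps, 1 + eps). For reference, here is a minimal, non-fused PyTorch sketch of the same element-wise update; `syre_wd_reference` is an illustrative name only, and unlike the kernel it resamples theta_0 and D on every call rather than deriving them deterministically from fixed seeds.

import torch


def syre_wd_reference(theta, gamma, std, eps):
    # Illustrative only: same element-wise update as _syre_wd_kernel below, but not fused
    # and with fresh randomness on each call (the kernel's theta_0 and D are fixed by seeds).
    theta_0 = torch.randn_like(theta) * std  # random reference point
    d = torch.rand_like(theta) * (2 * eps) + (1 - eps)  # per-element scale in [1 - eps, 1 + eps)
    return theta - gamma * d * (theta - theta_0)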
"""Implementation of the syre weight decay algorithm from "Remove Symmetries to Control Model | |
Expressivity and Improve Optimization" (https://arxiv.org/abs/2408.15495).""" | |
import math | |
import torch | |
from torch import optim | |
import triton | |
import triton.language as tl | |

@triton.jit
def _syre_wd_kernel(ptr, n_elements, gamma, std, eps, seed1, seed2, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    theta = tl.load(ptr + offsets, mask=mask)
    # theta_0 ~ N(0, std^2): the random reference point the weights decay toward
    theta_0 = tl.randn(seed1, offsets) * std
    # d ~ U(1 - eps, 1 + eps): the random per-element scale D on the decay strength
    d = tl.rand(seed2, offsets) * (eps * 2) + (1 - eps)
    # theta <- theta - gamma * D * (theta - theta_0)
    output = theta - gamma * d * (theta - theta_0)
    tl.store(ptr + offsets, output, mask=mask)

def syre_wd_(
    theta: torch.Tensor, gamma: float, std: float, eps: float, seed1: int, seed2: int
) -> torch.Tensor:
    """Apply the syre weight decay algorithm in-place to the input tensor.

    Args:
        theta (torch.Tensor): The tensor to apply weight decay to.
        gamma (float): The decay factor, which should be `weight_decay * lr`.
        std (float): The standard deviation for theta_0. This should be around 0.01 / sqrt(fan_in).
        eps (float): The epsilon value for D. This should be around 0.01.
        seed1 (int): The random seed for theta_0.
        seed2 (int): The random seed for D.

    Returns:
        torch.Tensor: The tensor after applying the syre weight decay.
    """
    if not theta.is_contiguous():
        raise ValueError("theta must be contiguous")
    n_elements = theta.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _syre_wd_kernel[grid](theta, n_elements, gamma, std, eps, seed1, seed2, BLOCK_SIZE=1024)
    return theta
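
# Example (illustrative): one in-place syre weight decay step on a CUDA tensor, with
# gamma = weight_decay * lr as the docstring suggests. The shapes, seeds, and values
# here are placeholders.
#
#   w = torch.randn(4096, 1024, device="cuda")
#   syre_wd_(w, gamma=1e-2 * 1e-3, std=0.01 / math.sqrt(1024), eps=0.01, seed1=1, seed2=2)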

class AdamSyre(optim.Optimizer):
    r"""Implements Adam algorithm with decoupled syre weight decay.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. The syre weight
    decay algorithm was proposed in `Remove Symmetries to Control Model Expressivity and Improve
    Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        syre_std (float, optional): standard deviation for theta_0 in syre weight decay
            (default: 0.01 / sqrt(768), where 768 is a common hidden dimension)
        syre_eps (float, optional): epsilon value for D in syre weight decay
            (default: 0.01)
        syre_seed (int, optional): seed for random number generation in syre weight decay
            (default: 0)
        maximize (bool, optional): if True, optimizes the parameters in the direction of
            increasing the objective (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _Remove Symmetries to Control Model Expressivity and Improve Optimization:
        https://arxiv.org/abs/2408.15495
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        syre_std=0.01 / math.sqrt(768),
        syre_eps=0.01,
        syre_seed=0,
        *,
        maximize=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= syre_std:
            raise ValueError("Invalid syre_std value: {}".format(syre_std))
        if not 0.0 <= syre_eps:
            raise ValueError("Invalid syre_eps value: {}".format(syre_eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            syre_std=syre_std,
            syre_eps=syre_eps,
            syre_seed=syre_seed,
            maximize=maximize,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # The generator is re-seeded with syre_seed on every step, so each parameter gets
            # the same (seed1, seed2) pair every step and its theta_0 and D stay fixed.
            gen = torch.Generator()
            gen.manual_seed(group["syre_seed"])
            for p in group["params"]:
                # Seeds are drawn for every parameter, before the gradient check, so a
                # parameter's seeds do not depend on which other parameters have gradients.
                seed1, seed2 = torch.empty(2, dtype=torch.int64).random_(generator=gen).tolist()
                if p.grad is None:
                    continue

                # Perform syre weight decay
                syre_wd_(
                    p,
                    group["weight_decay"] * group["lr"],
                    group["syre_std"],
                    group["syre_eps"],
                    seed1,
                    seed2,
                )

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("AdamSyre does not support sparse gradients")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
                step_size = group["lr"] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=step_size if group["maximize"] else -step_size)

        return loss
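
A minimal usage sketch of `AdamSyre` (illustrative only: the model, the `loader` iterable, and the hyperparameters are placeholders, with `syre_std` set from a 768-wide hidden dimension as in the default):

model = torch.nn.Sequential(
    torch.nn.Linear(768, 768), torch.nn.GELU(), torch.nn.Linear(768, 10)
).cuda()
opt = AdamSyre(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
    syre_std=0.01 / math.sqrt(768),
    syre_eps=0.01,
    syre_seed=0,
)

for x, y in loader:  # `loader` is a placeholder DataLoader
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x.cuda()), y.cuda())
    loss.backward()
    opt.step()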