Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active August 3, 2025 21:05
Show Gist options
  • Save crowsonkb/26002348a1514fcee034a78322dcf13e to your computer and use it in GitHub Desktop.
Save crowsonkb/26002348a1514fcee034a78322dcf13e to your computer and use it in GitHub Desktop.
Implementation of the syre weight decay algorithm from "Remove Symmetries to Control Model Expressivity and Improve Optimization" (https://arxiv.org/abs/2408.15495).
"""Implementation of the syre weight decay algorithm from "Remove Symmetries to Control Model
Expressivity and Improve Optimization" (https://arxiv.org/abs/2408.15495)."""
import math
import torch
from torch import optim
import triton
import triton.language as tl
@triton.jit
def _syre_wd_kernel(ptr, n_elements, gamma, std, eps, seed1, seed2, BLOCK_SIZE: tl.constexpr):
    """Update one BLOCK_SIZE-wide slice of the flat buffer ``ptr`` in place.

    For every element index ``i < n_elements`` covered by this program instance:

        theta_0[i] = tl.randn(seed1, i) * std
        d[i]       = tl.rand(seed2, i) * (2 * eps) + (1 - eps)
        ptr[i]    <- ptr[i] - gamma * d[i] * (ptr[i] - theta_0[i])

    The random draws depend only on the seed and the element index, so launching the
    kernel again with the same seeds reproduces the same ``theta_0`` and ``d``.
    """
    # Flat element indices handled by this program instance.
    # NOTE(review): indices are derived from the (int32) program id; confirm behaviour
    # for buffers with 2**31 or more elements.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Random anchor the weights are pulled toward, and a random per-element
    # scaling of the pull centred on 1.
    anchor = tl.randn(seed1, idx) * std
    scale = tl.rand(seed2, idx) * (eps * 2) + (1 - eps)

    cur = tl.load(ptr + idx, mask=in_bounds)
    tl.store(ptr + idx, cur - gamma * scale * (cur - anchor), mask=in_bounds)
def syre_wd_(
    theta: torch.Tensor, gamma: float, std: float, eps: float, seed1: int, seed2: int
) -> torch.Tensor:
    """Apply the syre weight decay algorithm in-place to the input tensor.

    Every element is moved a fraction ``gamma * d`` of the way toward a random anchor
    ``theta_0``. The anchor is a normal draw scaled by ``std``, generated from ``seed1``.
    ``d`` is a uniform draw rescaled to lie around ``[1 - eps, 1 + eps]``, generated from
    ``seed2``. Both draws are keyed on the seed and on the element's position in row-major
    (contiguous) order, so calls with the same seeds see the same ``theta_0`` and ``d``.

    Args:
        theta (torch.Tensor): The tensor to apply weight decay to. Tensors whose layout is
            not contiguous (e.g. transposed views, channels_last weights) are updated through
            a contiguous temporary whose result is copied back into ``theta``.
        gamma (float): The decay factor, which should be `weight_decay * lr`.
        std (float): The standard deviation for theta_0. This should be around 0.01 / sqrt(fan_in).
        eps (float): The epsilon value for D. This should be around 0.01.
        seed1 (int): The random seed for theta_0.
        seed2 (int): The random seed for D.

    Returns:
        torch.Tensor: ``theta`` itself, after applying the syre weight decay.
    """
    n_elements = theta.numel()
    if n_elements == 0:
        # Nothing to update; don't launch a kernel over an empty grid.
        return theta

    # The kernel addresses memory as one flat array, so it needs contiguous storage.
    # For other layouts, ``contiguous()`` returns a fresh row-major copy, which is
    # updated and then written back. For contiguous inputs it returns ``theta`` itself.
    work = theta.contiguous()

    block_size = 1024
    grid = (triton.cdiv(n_elements, block_size),)
    _syre_wd_kernel[grid](work, n_elements, gamma, std, eps, seed1, seed2, BLOCK_SIZE=block_size)

    if work is not theta:
        # The kernel write above bypasses autograd; keep the write-back consistent with it.
        with torch.no_grad():
            theta.copy_(work)
    return theta
class AdamSyre(optim.Optimizer):
    r"""Implements Adam algorithm with decoupled syre weight decay.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. The syre weight
    decay algorithm was proposed in `Remove Symmetries to Control Model Expressivity and Improve
    Optimization`_.

    On every step, before the Adam update, each parameter that has a gradient is moved a
    fraction ``weight_decay * lr * D`` of the way toward a random anchor ``theta_0`` (see
    :func:`syre_wd_`). Each parameter gets two seeds, which drive its ``theta_0`` and ``D``.
    The seeds are re-derived on every step from ``syre_seed`` and the parameter's position
    in its group. ``theta_0`` and ``D`` therefore stay the same from step to step without
    being stored in the optimizer state.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2). When
            ``weight_decay * lr`` is zero, the syre step is skipped for that group.
        syre_std (float, optional): standard deviation for theta_0 in syre weight decay
            (default: 0.01 / sqrt(768), where 768 is a common hidden dimension)
        syre_eps (float, optional): epsilon value for D in syre weight decay
            (default: 0.01)
        syre_seed (int, optional): seed for random number generation in syre weight decay
            (default: 0). Parameter groups sharing a ``syre_seed`` give identical seed pairs
            to parameters at the same position in their groups.
        maximize (bool, optional): if True, optimizes the parameters in the direction of
            increasing the objective (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _Remove Symmetries to Control Model Expressivity and Improve Optimization:
        https://arxiv.org/abs/2408.15495
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        syre_std=0.01 / math.sqrt(768),
        syre_eps=0.01,
        syre_seed=0,
        *,
        maximize=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= syre_std:
            raise ValueError("Invalid syre_std value: {}".format(syre_std))
        if not 0.0 <= syre_eps:
            raise ValueError("Invalid syre_eps value: {}".format(syre_eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            syre_std=syre_std,
            syre_eps=syre_eps,
            syre_seed=syre_seed,
            maximize=maximize,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # ``step`` indexes ``group["maximize"]``. Default it for states that lack the key,
        # as torch.optim.Adam does, so loading such a state doesn't make ``step`` raise KeyError.
        for group in self.param_groups:
            group.setdefault("maximize", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient. The check runs before that
                parameter is modified.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Re-seed from the group's syre_seed on every step. Each parameter then draws
            # the same (seed1, seed2) pair every step, keeping its theta_0 and D fixed.
            gen = torch.Generator()
            gen.manual_seed(group["syre_seed"])
            beta1, beta2 = group["betas"]
            gamma = group["weight_decay"] * group["lr"]

            for p in group["params"]:
                # Draw the seeds before the ``grad is None`` check. A parameter's seeds then
                # depend only on its position in the group, not on which parameters happen
                # to have gradients this step.
                seed1, seed2 = torch.empty(2, dtype=torch.int64).random_(generator=gen).tolist()
                if p.grad is None:
                    continue
                grad = p.grad
                # Reject unsupported gradients before touching ``p``, so a failed step
                # leaves the parameter unmodified.
                if grad.is_sparse:
                    raise RuntimeError("AdamSyre does not support sparse gradients")

                # Perform syre weight decay (decoupled from the Adam update below). With
                # gamma == 0 it would leave finite weights unchanged, so skip the full
                # read/write of the parameter (e.g. for no-decay parameter groups).
                if gamma != 0:
                    syre_wd_(
                        p,
                        gamma,
                        group["syre_std"],
                        group["syre_eps"],
                        seed1,
                        seed2,
                    )

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
                step_size = group["lr"] / bias_correction1
                # Descend by default; ascend along the update when maximizing.
                p.addcdiv_(exp_avg, denom, value=step_size if group["maximize"] else -step_size)

        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment