CAME: Confidence-guided Adaptive Memory Efficient Optimization, taken from the official repo (https://github.com/huawei-noah/Pretrained-Language-Model/blob/master/CAME/came.py).
import torch
import torch.optim

class CAME(torch.optim.Optimizer):
    """Implements the CAME algorithm.

    This implementation is based on:
    `CAME: Confidence-guided Adaptive Memory Efficient Optimization`

    Args:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups
        lr (float): external learning rate; must be supplied (default: None)
        eps (tuple[float, float]): regularization constants for the squared
            gradient and the instability, respectively
            (default: (1e-30, 1e-16))
        clip_threshold (float): threshold on the root-mean-square of the
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing
            running averages of the update, the squared gradient, and the
            instability (default: (0.9, 0.999, 0.9999))
        weight_decay (float, optional): weight decay coefficient, applied as
            decoupled decay scaled by lr (default: 0.0)
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
    ):
        assert lr is not None and lr > 0.0, "lr must be provided and positive"
        assert all(0.0 <= beta <= 1.0 for beta in betas)
        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def _get_options(self, param_shape):
        # Factor the second-moment statistics only for tensors with >= 2 dims.
        factored = len(param_shape) >= 2
        return factored

    def _rms(self, tensor):
        # Root mean square of a tensor: ||t||_2 / sqrt(numel(t)).
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        # Reconstruct the inverse square root of the second-moment matrix
        # from its factored row/column statistics (as in Adafactor); the row
        # factor is normalized by its mean so the product is well scaled.
        r_factor = (
            (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
            .rsqrt_()
            .unsqueeze(-1)
        )
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)
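
    # Sketch of the update rule implemented in step() (notation follows the
    # CAME paper; g_t is the gradient, theta_t the parameters):
    #   v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t**2          (factored, Adafactor-style)
    #   u_t     = clip(g_t / sqrt(v_t))                           (RMS-clipped update)
    #   m_t     = beta1 * m_{t-1} + (1 - beta1) * u_t             (EMA of updates)
    #   s_t     = beta3 * s_{t-1} + (1 - beta3) * (u_t - m_t)**2  (instability)
    #   theta_t = theta_{t-1} - lr * m_t / sqrt(s_t)
    # For non-factored (1-D) tensors, the instability step is skipped and
    # m_t is used directly.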
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("CAME does not support sparse gradients.")

                # Work on a float32 copy of fp16/bf16 parameters so the fp32
                # update computed below can be applied in place; the copy is
                # written back at the end of the step (mirroring Adafactor).
                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state = self.state[p]
                grad_shape = grad.shape

                factored = self._get_options(grad_shape)

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # Row/column accumulators for the factored second moment
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:]
                        ).type_as(grad)
                        # Row/column accumulators for the factored instability
                        state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)
                        state["exp_avg_res_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:]
                        ).type_as(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(group["betas"][1]).add_(
                        update.mean(dim=-1), alpha=1.0 - group["betas"][1]
                    )
                    exp_avg_sq_col.mul_(group["betas"][1]).add_(
                        update.mean(dim=-2), alpha=1.0 - group["betas"][1]
                    )

                    # Approximation of the exponential moving average of the
                    # squared gradient; _approx_sq_grad returns the inverse
                    # square root, so this gives g_t / sqrt(v_t).
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(group["betas"][1]).add_(
                        update, alpha=1.0 - group["betas"][1]
                    )
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Clip so that the root-mean-square of the update does not
                # exceed clip_threshold.
                update.div_(
                    (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
                )
exp_avg = state["exp_avg"] | |
exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0]) | |
# Confidence-guided strategy | |
# Calculation of instability | |
res = (update - exp_avg)**2 + group["eps"][1] | |
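                # A small residual means the instantaneous update agrees with
                # its EMA (high confidence), so the 1/sqrt(res) factor applied
                # in the factored branch below enlarges the step; disagreement
                # shrinks it.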
                if factored:
                    exp_avg_res_row = state["exp_avg_res_row"]
                    exp_avg_res_col = state["exp_avg_res_col"]

                    exp_avg_res_row.mul_(group["betas"][2]).add_(
                        res.mean(dim=-1), alpha=1.0 - group["betas"][2]
                    )
                    exp_avg_res_col.mul_(group["betas"][2]).add_(
                        res.mean(dim=-2), alpha=1.0 - group["betas"][2]
                    )

                    # Approximation of the exponential moving average of the
                    # instability (again returned as an inverse square root).
                    res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
                    update = res_approx.mul_(exp_avg)
                else:
                    # Non-factored (1-D) tensors skip the confidence rescaling
                    # and use the EMA of updates directly.
                    update = exp_avg
if group["weight_decay"] != 0: | |
p.data.add_( | |
p.data, alpha=-group["weight_decay"] * group["lr"] | |
) | |
update.mul_(group["lr"]) | |
p.data.add_(-update) | |
return loss |
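
A minimal usage sketch (not part of the original file; the model, data, and
hyperparameters below are illustrative only, with lr chosen arbitrarily):

import torch
import torch.nn as nn

model = nn.Linear(16, 4)  # 2-D weight takes the factored path; 1-D bias does not
optimizer = CAME(
    model.parameters(),
    lr=2e-4,
    betas=(0.9, 0.999, 0.9999),
    weight_decay=0.01,
)

x = torch.randn(32, 16)
y = torch.randn(32, 4)

for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()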