@sadimanna
Created June 30, 2021 10:49
import re

import torch
from torch.optim.optimizer import Optimizer, required

EETA_DEFAULT = 0.001

class LARS(Optimizer):
    """
    Layer-wise Adaptive Rate Scaling for large batch training.

    Introduced in "Large Batch Training of Convolutional Networks" by Y. You,
    I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum=0.9,
        use_nesterov=False,
        weight_decay=0.0,
        exclude_from_weight_decay=None,
        exclude_from_layer_adaptation=None,
        classic_momentum=True,
        eeta=EETA_DEFAULT,
    ):
        """Constructs a LARS optimizer.

        Args:
            params: Iterable of parameters to optimize, or dicts defining
                parameter groups.
            lr: A `float` for learning rate.
            momentum: A `float` for momentum.
            use_nesterov: A `boolean` for whether to use Nesterov momentum.
            weight_decay: A `float` for weight decay.
            exclude_from_weight_decay: A list of `string` patterns for variable
                screening; if any of the patterns appears in a variable's name,
                the variable is excluded from weight decay. For example, one
                could pass ['batch_normalization', 'bias'] to exclude BN and
                bias parameters from weight decay.
            exclude_from_layer_adaptation: Similar to exclude_from_weight_decay,
                but for layer adaptation. If it is None, it defaults to the same
                value as exclude_from_weight_decay.
            classic_momentum: A `boolean` for whether to use classic (or popular)
                momentum. The learning rate is applied during the momentum update
                in classic momentum, but after momentum for popular momentum.
            eeta: A `float` for scaling the learning rate when computing the
                trust ratio.
        """
        self.epoch = 0
        defaults = dict(
            lr=lr,
            momentum=momentum,
            use_nesterov=use_nesterov,
            weight_decay=weight_decay,
            exclude_from_weight_decay=exclude_from_weight_decay,
            exclude_from_layer_adaptation=exclude_from_layer_adaptation,
            classic_momentum=classic_momentum,
            eeta=eeta,
        )

        super(LARS, self).__init__(params, defaults)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eeta = eeta
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation is set to exclude_from_weight_decay if
        # the arg is None.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def step(self, epoch=None, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        if epoch is None:
            epoch = self.epoch
            self.epoch += 1

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            eeta = group["eeta"]
            lr = group["lr"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                param = p.data
                grad = p.grad.data

                param_state = self.state[p]

                # TODO: get param names
                # if self._use_weight_decay(param_name):
                # Add L2 regularization to the gradient (out of place, so the
                # gradient stored in p.grad is left untouched).
                grad = grad + weight_decay * param

                if self.classic_momentum:
                    trust_ratio = 1.0

                    # TODO: get param names
                    # if self._do_layer_adaptation(param_name):
                    w_norm = torch.norm(param)
                    g_norm = torch.norm(grad)

                    # LARS trust ratio: eeta * ||w|| / ||g||, falling back to
                    # 1.0 whenever either norm is zero.
                    one = torch.ones_like(w_norm)
                    trust_ratio = torch.where(
                        w_norm.gt(0),
                        torch.where(g_norm.gt(0), eeta * w_norm / g_norm, one),
                        one,
                    ).item()

                    scaled_lr = lr * trust_ratio
                    if "momentum_buffer" not in param_state:
                        next_v = param_state["momentum_buffer"] = torch.zeros_like(
                            p.data
                        )
                    else:
                        next_v = param_state["momentum_buffer"]

                    # v <- momentum * v + scaled_lr * grad
                    next_v.mul_(momentum).add_(grad, alpha=scaled_lr)
                    if self.use_nesterov:
                        update = (momentum * next_v) + (scaled_lr * grad)
                    else:
                        update = next_v

                    p.data.add_(-update)
                else:
                    raise NotImplementedError

        return loss

    def _use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            for r in self.exclude_from_layer_adaptation:
                if re.search(r, param_name) is not None:
                    return False
        return True
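
A minimal usage sketch follows. The model, batch size, and hyperparameter values are illustrative placeholders, not part of the gist; any `torch.nn.Module` and standard training loop would be wired up the same way.

import torch
import torch.nn as nn

# Toy model and data; stand-ins for whatever you actually train with LARS.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
criterion = nn.CrossEntropyLoss()

# Hyperparameters here are placeholders; LARS is typically paired with a large
# batch size and a scaled-up base learning rate plus a warmup/decay schedule.
# Note: the name-based exclusion lists are accepted by the constructor but not
# yet applied inside step() (see the TODOs above).
optimizer = LARS(
    model.parameters(),
    lr=0.3,
    momentum=0.9,
    weight_decay=1e-6,
    exclude_from_weight_decay=["batch_normalization", "bias"],
)

inputs = torch.randn(32, 128)
targets = torch.randint(0, 10, (32,))

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()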