An Optax implementation of the Grams, c-AdamW, and c-Lion optimizers
from collections.abc import Callable
from typing import Any, NamedTuple, Optional, Union

import chex
import jax
import jax.numpy as jnp
import optax.tree
from optax._src import base, combine, numerics, transform, utils
from optax._src.transform import ScaleByAdamState
def scale_by_grams(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[chex.ArrayDType] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformation:
  r"""Rescale updates according to the Grams algorithm.

  Grams keeps the magnitude of the Adam update but takes the sign of each
  entry from the current gradient. See :func:`optax.adam` for details of the
  underlying moment estimation.

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    nesterov: Whether to use Nesterov momentum. The variant of Adam with
      Nesterov momentum is described in [Dozat 2016].

  Returns:
    A :class:`optax.GradientTransformation` object.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = optax.tree.zeros_like(params, dtype=mu_dtype)  # First moment
    nu = optax.tree.zeros_like(params)  # Second moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = optax.tree.update_moment(updates, state.mu, b1, 1)
    nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_increment(state.count)
    if nesterov:
      mu_hat = jax.tree.map(
          lambda m, g: b1 * m + (1 - b1) * g,
          optax.tree.bias_correction(mu, b1, numerics.safe_increment(count_inc)),
          optax.tree.bias_correction(updates, b1, count_inc),
      )
    else:
      mu_hat = optax.tree.bias_correction(mu, b1, count_inc)
    # Dozat 2016 https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ
    # Algorithm 2 further multiplies Adam's standard nu_hat by b2. It is
    # unclear why. Other Nadam implementations also omit the extra b2 factor.
    nu_hat = optax.tree.bias_correction(nu, b2, count_inc)
    # Grams: Adam's magnitude with the sign taken from the raw gradient.
    updates = jax.tree.map(
        lambda u, m, v: None if m is None else jnp.sign(u) * jnp.abs(m / (jnp.sqrt(v + eps_root) + eps)),
        updates,
        mu_hat,
        nu_hat,
        is_leaf=lambda x: x is None,
    )
    mu = optax.tree.cast(mu, mu_dtype)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
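# Illustrative sketch (not used by the transformations above; the name is
# hypothetical): the Grams rule applied to a single leaf, assuming `grad` and
# the bias-corrected moments `mu_hat`, `nu_hat` are plain arrays. The update
# keeps Adam's magnitude but the raw gradient's sign, e.g. grad = -1.0,
# mu_hat = 0.5, nu_hat = 1.0 yields approximately -0.5.
def _grams_leaf_update_sketch(grad, mu_hat, nu_hat, eps=1e-8, eps_root=0.0):
  return jnp.sign(grad) * jnp.abs(mu_hat / (jnp.sqrt(nu_hat + eps_root) + eps))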
def grams(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformationExtraArgs:
  r"""Grams with decoupled weight decay regularization.

  Like AdamW, this optimizer uses weight decay to regularize learning towards
  small weights, as this leads to better generalization. In SGD you can also
  use L2 regularization to implement this as an additive loss term, however
  L2 regularization does not behave as intended for adaptive gradient
  algorithms such as Adam, see [Loshchilov et al, 2019].

  Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`,
  :math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments
  ``b1``, ``b2``, ``eps`` and ``eps_root`` respectively. The learning rate is
  indexed by :math:`t` since the learning rate may also be provided by a
  schedule function. Let :math:`\lambda` be the weight decay and
  :math:`\theta_t` the parameter vector at time :math:`t`.

  The ``init`` function of this optimizer initializes an internal state
  :math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the
  first and second moments. In practice these values are stored as pytrees
  containing all zeros, with the same shape as the model updates.
  At step :math:`t`, the ``update`` function of this optimizer takes as
  arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t`
  and the parameters :math:`\theta_t` and computes updates :math:`u_t` and
  new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have,

  .. math::

    \begin{align*}
      m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
      v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
      \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
      \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
      u_t &\leftarrow -\alpha_t \cdot \left( sign(g_t) \odot \left|
        \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} +
        \varepsilon} \right) \right| + \lambda \theta_{t} \right)\\
      S_t &\leftarrow (m_t, v_t).
    \end{align*}

  This implementation can incorporate Nesterov-style momentum, as introduced
  by [Dozat 2016]. With the keyword argument `nesterov=True`, the optimizer
  replaces the above :math:`\hat{m}_t` with

  .. math::
    \hat{m}_t \leftarrow
      \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate to track the first moment of past gradients.
    b2: Exponential decay rate to track the second moment of past gradients.
    eps: A small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to denominator inside the square root
      (as in RMSProp), to avoid dividing by zero when rescaling. This is
      needed for instance when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    weight_decay: Strength of the weight decay regularization. Note that this
      weight decay is multiplied with the learning rate. This is consistent
      with other frameworks such as PyTorch, but different from
      (Loshchilov et al, 2019) where the weight decay is only multiplied with
      the "schedule multiplier", but not the base learning rate.
    mask: A tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the weight decay to, and `False` for those you want to skip. Note
      that the Grams gradient transformations are applied to all parameters.
    nesterov: Whether to use Nesterov momentum, applying the correction of
      [Dozat 2016] to the first moment.

  Returns:
    The corresponding :class:`optax.GradientTransformationExtraArgs`.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = grams(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...  grad = jax.grad(f)(params)
    ...  updates, opt_state = solver.update(grad, opt_state, params)
    ...  params = optax.apply_updates(params, updates)
    ...  print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01

  References:
    Loshchilov et al, `Decoupled Weight Decay
    Regularization <https://arxiv.org/abs/1711.05101>`_, 2019

    Dozat, `Incorporating Nesterov Momentum into Adam
    <https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_, 2016

  .. seealso::
    See the related functions :func:`optax.adam` and :func:`optax.adamw`, as
    well as the example :doc:`../_collections/examples/nanolm` for a use case.
  """
  return combine.chain(
      scale_by_grams(
          b1=b1,
          b2=b2,
          eps=eps,
          eps_root=eps_root,
          mu_dtype=mu_dtype,
          nesterov=nesterov,
      ),
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate),
  )
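# Usage sketch for the `mask` argument above (the dict-style params tree and
# the 'bias' key are hypothetical): apply weight decay to every leaf except
# those stored under a 'bias' entry, e.g.
# grams(learning_rate=1e-3, weight_decay=1e-2, mask=_weight_decay_mask_sketch).
def _weight_decay_mask_sketch(params):
  return jax.tree_util.tree_map_with_path(
      lambda path, _: not any(getattr(k, "key", None) == "bias" for k in path),
      params,
  )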
def scale_by_cadam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[chex.ArrayDType] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformation:
  r"""Rescale updates according to the cautious Adam (c-Adam) rule.

  The Adam update is kept only where its sign agrees with the current
  gradient and is set to zero elsewhere. See :func:`optax.adam` for details
  of the underlying moment estimation.

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    nesterov: Whether to use Nesterov momentum. The variant of Adam with
      Nesterov momentum is described in [Dozat 2016].

  Returns:
    A :class:`optax.GradientTransformation` object.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = optax.tree.zeros_like(params, dtype=mu_dtype)  # First moment
    nu = optax.tree.zeros_like(params)  # Second moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = optax.tree.update_moment(updates, state.mu, b1, 1)
    nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_increment(state.count)
    if nesterov:
      mu_hat = jax.tree.map(
          lambda m, g: b1 * m + (1 - b1) * g,
          optax.tree.bias_correction(mu, b1, numerics.safe_increment(count_inc)),
          optax.tree.bias_correction(updates, b1, count_inc),
      )
    else:
      mu_hat = optax.tree.bias_correction(mu, b1, count_inc)
    # Dozat 2016 https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ
    # Algorithm 2 further multiplies Adam's standard nu_hat by b2. It is
    # unclear why. Other Nadam implementations also omit the extra b2 factor.
    nu_hat = optax.tree.bias_correction(nu, b2, count_inc)
    # Cautious Adam: keep the Adam step (sign included) only where it agrees
    # in sign with the raw gradient; masked entries become zero.
    updates = jax.tree.map(
        lambda u, m, v: None if m is None else (u * m > 0) * m / (jnp.sqrt(v + eps_root) + eps),
        updates,
        mu_hat,
        nu_hat,
        is_leaf=lambda x: x is None,
    )
    mu = optax.tree.cast(mu, mu_dtype)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
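# Illustrative sketch (not used by the transformations above; the name is
# hypothetical): the cautious-Adam rule applied to a single leaf, assuming
# plain arrays. The Adam step is zeroed wherever its sign disagrees with the
# raw gradient, e.g. grad = -1.0 with a positive Adam step yields 0.0.
def _cautious_adam_leaf_update_sketch(grad, mu_hat, nu_hat, eps=1e-8, eps_root=0.0):
  adam_step = mu_hat / (jnp.sqrt(nu_hat + eps_root) + eps)
  return (grad * adam_step > 0) * adam_step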
def cadamw(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformationExtraArgs:
  r"""Cautious AdamW (c-AdamW): AdamW with a sign-consistency mask.

  Like AdamW, this optimizer uses weight decay to regularize learning towards
  small weights, as this leads to better generalization. In SGD you can also
  use L2 regularization to implement this as an additive loss term, however
  L2 regularization does not behave as intended for adaptive gradient
  algorithms such as Adam, see [Loshchilov et al, 2019]. In addition, the
  Adam update is applied only where it agrees in sign with the current
  gradient.

  Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`,
  :math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments
  ``b1``, ``b2``, ``eps`` and ``eps_root`` respectively. The learning rate is
  indexed by :math:`t` since the learning rate may also be provided by a
  schedule function. Let :math:`\lambda` be the weight decay and
  :math:`\theta_t` the parameter vector at time :math:`t`.

  The ``init`` function of this optimizer initializes an internal state
  :math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the
  first and second moments. In practice these values are stored as pytrees
  containing all zeros, with the same shape as the model updates.
  At step :math:`t`, the ``update`` function of this optimizer takes as
  arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t`
  and the parameters :math:`\theta_t` and computes updates :math:`u_t` and
  new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have,

  .. math::

    \begin{align*}
      m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
      v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
      \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
      \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
      u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t
        + \bar{\varepsilon}} + \varepsilon} \right) \odot
        \mathbb{1}\left[ \hat{m}_t \odot g_t > 0 \right]
        + \lambda \theta_{t} \right)\\
      S_t &\leftarrow (m_t, v_t).
    \end{align*}

  This implementation can incorporate Nesterov-style momentum, as introduced
  by [Dozat 2016]. With the keyword argument `nesterov=True`, the optimizer
  replaces the above :math:`\hat{m}_t` with

  .. math::
    \hat{m}_t \leftarrow
      \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate to track the first moment of past gradients.
    b2: Exponential decay rate to track the second moment of past gradients.
    eps: A small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to denominator inside the square root
      (as in RMSProp), to avoid dividing by zero when rescaling. This is
      needed for instance when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    weight_decay: Strength of the weight decay regularization. Note that this
      weight decay is multiplied with the learning rate. This is consistent
      with other frameworks such as PyTorch, but different from
      (Loshchilov et al, 2019) where the weight decay is only multiplied with
      the "schedule multiplier", but not the base learning rate.
    mask: A tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the weight decay to, and `False` for those you want to skip. Note
      that the Adam gradient transformations are applied to all parameters.
    nesterov: Whether to use Nesterov momentum, applying the correction of
      [Dozat 2016] to the first moment.

  Returns:
    The corresponding :class:`optax.GradientTransformationExtraArgs`.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = cadamw(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...  grad = jax.grad(f)(params)
    ...  updates, opt_state = solver.update(grad, opt_state, params)
    ...  params = optax.apply_updates(params, updates)
    ...  print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01

  References:
    Loshchilov et al, `Decoupled Weight Decay
    Regularization <https://arxiv.org/abs/1711.05101>`_, 2019

    Dozat, `Incorporating Nesterov Momentum into Adam
    <https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_, 2016

  .. seealso::
    See the related functions :func:`optax.adam` and :func:`optax.adamw`, as
    well as the example :doc:`../_collections/examples/nanolm` for a use case.
  """
  return combine.chain(
      scale_by_cadam(
          b1=b1,
          b2=b2,
          eps=eps,
          eps_root=eps_root,
          mu_dtype=mu_dtype,
          nesterov=nesterov,
      ),
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate),
  )
class ScaleByLionState(NamedTuple):
  """State for the Lion algorithm."""

  count: chex.Array  # shape=(), dtype=jnp.int32.
  mu: base.Updates


def scale_by_clion(
    b1: float = 0.9,
    b2: float = 0.99,
    mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
  """Rescale updates according to the cautious Lion (c-Lion) rule.

  The Lion sign update is kept only where it agrees in sign with the current
  gradient and is set to zero elsewhere. See :func:`optax.lion` for more
  details.

  Args:
    b1: Rate for combining the momentum and the current grad.
    b2: Decay rate for the exponentially weighted average of grads.
    mu_dtype: Optional `dtype` to be used for the momentum; if `None` then the
      `dtype` is inferred from `params` and `updates`.

  Returns:
    A :class:`optax.GradientTransformation` object.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = optax.tree.zeros_like(params, dtype=mu_dtype)  # moment
    return ScaleByLionState(count=jnp.zeros([], jnp.int32), mu=mu)

  def update_fn(updates, state, params=None):
    del params

    def calc_update(g, m):
      # Cautious Lion: zero the sign update wherever the interpolated
      # direction disagrees in sign with the raw gradient.
      u = (1.0 - b1) * g + b1 * m
      return (u * g > 0) * jnp.sign(u)

    updates_new = jax.tree.map(calc_update, updates, state.mu)
    mu = optax.tree.update_moment(updates, state.mu, b2, 1)
    mu = optax.tree.cast(mu, mu_dtype)
    count_inc = numerics.safe_increment(state.count)
    return updates_new, ScaleByLionState(count=count_inc, mu=mu)

  return base.GradientTransformation(init_fn, update_fn)
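# Illustrative sketch (not used by the transformations above; the name is
# hypothetical): the cautious-Lion rule applied to a single leaf, assuming
# `grad` and the momentum `mu` are plain arrays. With grad = 1.0 and
# mu = -2.0 the interpolated direction is negative, disagrees with the
# gradient, and the update is masked to 0.0.
def _cautious_lion_leaf_update_sketch(grad, mu, b1=0.9):
  direction = (1.0 - b1) * grad + b1 * mu
  return (direction * grad > 0) * jnp.sign(direction)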
def clion(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.99,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-3,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformationExtraArgs:
  r"""The cautious Lion (c-Lion) optimizer.

  Lion was discovered by symbolic program search. Unlike most adaptive
  optimizers such as AdamW, Lion only tracks momentum, making it more
  memory-efficient. The update of Lion is produced through the sign operation,
  resulting in a larger norm compared to updates produced by other optimizers
  such as SGD and AdamW. A suitable learning rate for Lion is typically 3-10x
  smaller than that for AdamW; the weight decay for Lion should in turn be
  3-10x larger than that for AdamW to maintain a similar strength (lr * wd).
  The cautious variant additionally zeroes the update wherever it disagrees
  in sign with the current gradient.

  Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`
  represent the arguments ``b1`` and ``b2`` respectively. The learning rate is
  indexed by :math:`t` since the learning rate may also be provided by a
  schedule function. Let :math:`\lambda` be the weight decay and
  :math:`\theta_t` the parameter vector at time :math:`t`.

  The ``init`` function of this optimizer initializes an internal state
  :math:`S_0 := (m_0) = (0)`, representing the initial estimate for the
  first moment. In practice these values are stored as pytrees
  containing all zeros, with the same shape as the model updates.
  At step :math:`t`, the ``update`` function of this optimizer takes as
  arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t`
  and the parameters :math:`\theta_t` and computes updates :math:`u_t` and
  new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have,

  .. math::

    \begin{align*}
      c_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
      u_t &\leftarrow -\alpha_t \cdot \left( sign \left( c_t \right) \odot
        \mathbb{1}\left[ c_t \odot g_t > 0 \right] +
        \lambda \theta_{t} \right)\\
      m_t &\leftarrow \beta_2 \cdot m_{t-1} + (1-\beta_2) \cdot g_t \\
      S_t &\leftarrow (m_t).
    \end{align*}

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Rate to combine the momentum and the current gradient.
    b2: Exponential decay rate to track the momentum of past gradients.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    weight_decay: Strength of the weight decay regularization. Note that this
      weight decay is multiplied with the learning rate. This is consistent
      with other frameworks such as PyTorch, but different from (Loshchilov et
      al, 2019) where the weight decay is only multiplied with the "schedule
      multiplier", but not the base learning rate.
    mask: A tree with same structure as (or a prefix of) the params PyTree, or
      a Callable that returns such a pytree given the params/updates. The
      leaves should be booleans, `True` for leaves/subtrees you want to apply
      the weight decay to, and `False` for those you want to skip. Note that
      the Lion gradient transformations are applied to all parameters.

  Returns:
    The corresponding :class:`optax.GradientTransformationExtraArgs`.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = clion(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...  grad = jax.grad(f)(params)
    ...  updates, opt_state = solver.update(grad, opt_state, params)
    ...  params = optax.apply_updates(params, updates)
    ...  print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01

  References:
    Chen et al, `Symbolic Discovery of Optimization Algorithms
    <https://arxiv.org/abs/2302.06675>`_, 2023
  """
  return combine.chain(
      scale_by_clion(b1=b1, b2=b2, mu_dtype=mu_dtype),
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate),
  )
def amsgrad(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
) -> base.GradientTransformationExtraArgs:
  """The AMSGrad optimizer.

  The original Adam can fail to converge to the optimal solution in some cases.
  AMSGrad guarantees convergence by using a long-term memory of past gradients.

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate to track the first moment of past gradients.
    b2: Exponential decay rate to track the second moment of past gradients.
    eps: A small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to denominator inside the square root
      (as in RMSProp), to avoid dividing by zero when rescaling. This is
      needed for instance when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    The corresponding :class:`optax.GradientTransformationExtraArgs`.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
    >>> solver = optax.amsgrad(learning_rate=0.003)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...  grad = jax.grad(f)(params)
    ...  updates, opt_state = solver.update(grad, opt_state, params)
    ...  params = optax.apply_updates(params, updates)
    ...  print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.40E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.39E+01
    Objective function: 1.38E+01

  References:
    Reddi et al, `On the Convergence of Adam and Beyond
    <https://openreview.net/forum?id=ryQu7f-RZ>`_, 2018
  """
  return combine.chain(
      transform.scale_by_amsgrad(
          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype
      ),
      transform.scale_by_learning_rate(learning_rate),
  )
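if __name__ == "__main__":
  # Minimal usage sketch (illustration only; the toy objective, learning rate,
  # and step count are arbitrary choices, not recommendations): fit a simple
  # quadratic with each optimizer defined above.
  def objective(params):
    return jnp.sum(params**2)

  for name, optimizer in [
      ("grams", grams(learning_rate=1e-2)),
      ("c-adamw", cadamw(learning_rate=1e-2)),
      ("c-lion", clion(learning_rate=1e-2)),
  ]:
    params = jnp.array([1.0, 2.0, 3.0])
    opt_state = optimizer.init(params)
    for _ in range(100):
      grads = jax.grad(objective)(params)
      updates, opt_state = optimizer.update(grads, opt_state, params)
      params = optax.apply_updates(params, updates)
    print(f"{name}: objective after 100 steps = {objective(params):.4f}")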
TODO: revise the docstrings further; they are adapted from the Optax documentation for AdamW and Lion.