@kstoneriv3
Last active June 25, 2025 02:34
An Optax implementation of the Grams, C-AdamW, and C-Lion optimizers
import functools
import warnings
from collections.abc import Callable
from typing import Any, NamedTuple, Optional, Union
import chex
import jax
import jax.numpy as jnp
import optax.tree
from jax import nn
from optax._src import (
base,
clipping,
combine,
factorized,
numerics,
transform,
utils,
wrappers,
)
from optax._src import linesearch as _linesearch
from optax._src.transform import ScaleByAdamState
from optax.transforms import _accumulation, _adding
def scale_by_grams(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
*,
nesterov: bool = False,
) -> base.GradientTransformation:
r"""Rescale updates according to the Adam algorithm.
See :func:`optax.adam` for more details.
Args:
b1: Decay rate for the exponentially weighted average of grads.
b2: Decay rate for the exponentially weighted average of squared grads.
eps: Term added to the denominator to improve numerical stability.
eps_root: Term added to the denominator inside the square-root to improve
numerical stability when backpropagating gradients through the rescaling.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
nesterov: Whether to use Nesterov momentum. The variant of Adam with
Nesterov momentum is described in [Dozat 2016]
Returns:
A :class:`optax.GradientTransformation` object.
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)
def init_fn(params):
mu = optax.tree.zeros_like(params, dtype=mu_dtype) # First moment
nu = optax.tree.zeros_like(params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)
def update_fn(updates, state, params=None):
del params
mu = optax.tree.update_moment(updates, state.mu, b1, 1)
nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
count_inc = numerics.safe_increment(state.count)
if nesterov:
mu_hat = jax.tree.map(
lambda m, g: b1 * m + (1 - b1) * g,
optax.tree.bias_correction(mu, b1,
numerics.safe_increment(count_inc)),
optax.tree.bias_correction(updates, b1, count_inc),
)
else:
mu_hat = optax.tree.bias_correction(mu, b1, count_inc)
# Dozat 2016 https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ
# Algorithm 2 further multiplies Adam's standard nu_hat by b2. It is
# unclear why. Other Nadam implementations also omit the extra b2 factor.
nu_hat = optax.tree.bias_correction(nu, b2, count_inc)
updates = jax.tree.map(
lambda u, m, v: None if m is None else jnp.sign(u) * jnp.abs(m / (jnp.sqrt(v + eps_root) + eps)),
updates,
mu_hat,
nu_hat,
is_leaf=lambda x: x is None,
)
mu = optax.tree.cast(mu, mu_dtype)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
return base.GradientTransformation(init_fn, update_fn)
def grams(
learning_rate: base.ScalarOrSchedule,
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[Any] = None,
weight_decay: float = 1e-4,
mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
*,
nesterov: bool = False,
) -> base.GradientTransformationExtraArgs:
r"""Adam with weight decay regularization.
AdamW uses weight decay to regularize learning towards small weights, as
this leads to better generalization. In SGD you can also use L2 regularization
to implement this as an additive loss term, however L2 regularization
does not behave as intended for adaptive gradient algorithms such as Adam,
see [Loshchilov et al, 2019].
Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`,
:math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments
``b1``, ``b2``, ``eps`` and ``eps_root`` respectively. The learning rate is
indexed by :math:`t` since the learning rate may also be provided by a
schedule function. Let :math:`\lambda` be the weight decay and
:math:`\theta_t` the parameter vector at time :math:`t`.
The ``init`` function of this optimizer initializes an internal state
:math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the
first and second moments. In practice these values are stored as pytrees
containing all zeros, with the same shape as the model updates.
At step :math:`t`, the ``update`` function of this optimizer takes as
arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t`
and the parameters :math:`\theta_t` and computes updates :math:`u_t` and
new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have,
.. math::
\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
\hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
\hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t
+ \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t} \right)\\
S_t &\leftarrow (m_t, v_t).
\end{align*}
This implementation can also incorporate Nesterov-style momentum, as
introduced by [Dozat 2016]. With the keyword argument `nesterov=True`, the
optimizer replaces the above :math:`\hat{m}_t` (before the sign correction
from the gradient) with
.. math::
\hat{m}_t \leftarrow
\beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Exponential decay rate to track the first moment of past gradients.
b2: Exponential decay rate to track the second moment of past gradients.
eps: A small constant applied to denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
weight_decay: Strength of the weight decay regularization. Note that this
weight decay is multiplied with the learning rate. This is consistent
with other frameworks such as PyTorch, but different from
(Loshchilov et al, 2019) where the weight decay is only multiplied with
the "schedule multiplier", but not the base learning rate.
mask: A tree with same structure as (or a prefix of) the params PyTree,
or a Callable that returns such a pytree given the params/updates.
The leaves should be booleans, `True` for leaves/subtrees you want to
apply the weight decay to, and `False` for those you want to skip. Note
that the Adam gradient transformations are applied to all parameters.
nesterov: Whether to use Nesterov momentum for the first-moment estimate,
as described in [Dozat 2016].
Returns:
The corresponding :class:`optax.GradientTransformationExtraArgs`.
Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = grams(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References:
`Grams: Gradient Descent with Adaptive Momentum Scaling`, 2024
Loshchilov et al, `Decoupled Weight Decay
Regularization <https://arxiv.org/abs/1711.05101>`_, 2019
Dozat, `Incorporating Nesterov Momentum into Adam
<https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_, 2016
.. seealso::
See the related functions :func:`optax.adam` and :func:`optax.adamw`.
"""
return combine.chain(
scale_by_grams(
b1=b1,
b2=b2,
eps=eps,
eps_root=eps_root,
mu_dtype=mu_dtype,
nesterov=nesterov,
),
transform.add_decayed_weights(weight_decay, mask),
transform.scale_by_learning_rate(learning_rate),
)
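# --- Editor's sketch (not part of the original gist) -------------------------
# A minimal, hypothetical check of the Grams property implemented above:
# after chaining with scale_by_learning_rate, each coordinate of the final
# update points against the *current gradient* (not the momentum), while its
# magnitude matches the Adam-style preconditioned step. The helper name and
# the numbers below are illustrative assumptions only.
def _demo_grams_sign_follows_gradient():
    params = jnp.array([1.0, -1.0])
    opt = grams(learning_rate=1.0, weight_decay=0.0)
    state = opt.init(params)
    grad = jnp.array([0.5, -0.5])  # gradients with opposite signs per coordinate
    updates, state = opt.update(grad, state, params)
    # Final update is -lr * sign(g) * |m_hat / (sqrt(v_hat) + eps)|, so each
    # coordinate moves against its current gradient.
    assert bool(jnp.all(jnp.sign(updates) == -jnp.sign(grad)))
    return updates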
def scale_by_cadam(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
*,
nesterov: bool = False,
) -> base.GradientTransformation:
r"""Rescale updates according to the Adam algorithm.
See :func:`optax.adam` for more details.
Args:
b1: Decay rate for the exponentially weighted average of grads.
b2: Decay rate for the exponentially weighted average of squared grads.
eps: Term added to the denominator to improve numerical stability.
eps_root: Term added to the denominator inside the square-root to improve
numerical stability when backpropagating gradients through the rescaling.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
nesterov: Whether to use Nesterov momentum. The variant of Adam with
Nesterov momentum is described in [Dozat 2016]
Returns:
A :class:`optax.GradientTransformation` object.
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)
def init_fn(params):
mu = optax.tree.zeros_like(params, dtype=mu_dtype) # First moment
nu = optax.tree.zeros_like(params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)
def update_fn(updates, state, params=None):
del params
mu = optax.tree.update_moment(updates, state.mu, b1, 1)
nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
count_inc = numerics.safe_increment(state.count)
if nesterov:
mu_hat = jax.tree.map(
lambda m, g: b1 * m + (1 - b1) * g,
optax.tree.bias_correction(mu, b1,
numerics.safe_increment(count_inc)),
optax.tree.bias_correction(updates, b1, count_inc),
)
else:
mu_hat = optax.tree.bias_correction(mu, b1, count_inc)
# Dozat 2016 https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ
# Algorithm 2 further multiplies Adam's standard nu_hat by b2. It is
# unclear why. Other Nadam implementations also omit the extra b2 factor.
nu_hat = optax.tree.bias_correction(nu, b2, count_inc)
updates = jax.tree.map(
lambda u, m, v: None if m is None else (u * m > 0) * jnp.sign(u) * jnp.abs(m / (jnp.sqrt(v + eps_root) + eps)),
updates,
mu_hat,
nu_hat,
is_leaf=lambda x: x is None,
)
mu = optax.tree.cast(mu, mu_dtype)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
return base.GradientTransformation(init_fn, update_fn)
def cadamw(
learning_rate: base.ScalarOrSchedule,
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[Any] = None,
weight_decay: float = 1e-4,
mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
*,
nesterov: bool = False,
) -> base.GradientTransformationExtraArgs:
r"""Adam with weight decay regularization.
AdamW uses weight decay to regularize learning towards small weights, as
this leads to better generalization. In SGD you can also use L2 regularization
to implement this as an additive loss term, however L2 regularization
does not behave as intended for adaptive gradient algorithms such as Adam,
see [Loshchilov et al, 2019].
Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`,
:math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments
``b1``, ``b2``, ``eps`` and ``eps_root`` respectively. The learning rate is
indexed by :math:`t` since the learning rate may also be provided by a
schedule function. Let :math:`\lambda` be the weight decay and
:math:`\theta_t` the parameter vector at time :math:`t`.
The ``init`` function of this optimizer initializes an internal state
:math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the
first and second moments. In practice these values are stored as pytrees
containing all zeros, with the same shape as the model updates.
At step :math:`t`, the ``update`` function of this optimizer takes as
arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t`
and the parameters :math:`\theta_t` and computes updates :math:`u_t` and
new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have,
.. math::
\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
\hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
\hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t
+ \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t} \right)\\
S_t &\leftarrow (m_t, v_t).
\end{align*}
This implementation can also incorporate Nesterov-style momentum, as
introduced by [Dozat 2016]. With the keyword argument `nesterov=True`, the
optimizer replaces the above :math:`\hat{m}_t` (before the cautious masking)
with
.. math::
\hat{m}_t \leftarrow
\beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Exponential decay rate to track the first moment of past gradients.
b2: Exponential decay rate to track the second moment of past gradients.
eps: A small constant applied to denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
weight_decay: Strength of the weight decay regularization. Note that this
weight decay is multiplied with the learning rate. This is consistent
with other frameworks such as PyTorch, but different from
(Loshchilov et al, 2019) where the weight decay is only multiplied with
the "schedule multiplier", but not the base learning rate.
mask: A tree with same structure as (or a prefix of) the params PyTree,
or a Callable that returns such a pytree given the params/updates.
The leaves should be booleans, `True` for leaves/subtrees you want to
apply the weight decay to, and `False` for those you want to skip. Note
that the Adam gradient transformations are applied to all parameters.
nesterov: Whether to use Nesterov momentum for the first-moment estimate,
as described in [Dozat 2016].
Returns:
The corresponding :class:`optax.GradientTransformationExtraArgs`.
Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = cadamw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References:
`Cautious Optimizers: Improving Training with One Line of Code`, 2024
Loshchilov et al, `Decoupled Weight Decay
Regularization <https://arxiv.org/abs/1711.05101>`_, 2019
Dozat, `Incorporating Nesterov Momentum into Adam
<https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_, 2016
.. seealso::
See the related functions :func:`optax.adam` and :func:`optax.adamw`.
"""
return combine.chain(
scale_by_cadam(
b1=b1,
b2=b2,
eps=eps,
eps_root=eps_root,
mu_dtype=mu_dtype,
nesterov=nesterov,
),
transform.add_decayed_weights(weight_decay, mask),
transform.scale_by_learning_rate(learning_rate),
)
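# --- Editor's sketch (not part of the original gist) -------------------------
# A hypothetical illustration of the cautious mask in `cadamw`: after a few
# steps that build positive momentum, a coordinate whose gradient flips sign
# receives exactly zero update (the mask 1[g * m > 0] fails), while a
# coordinate whose gradient still agrees with its momentum is updated as
# usual. The helper name and the numbers are assumptions for illustration only.
def _demo_cadamw_cautious_mask():
    params = jnp.array([1.0, 1.0])
    opt = cadamw(learning_rate=0.1, weight_decay=0.0)
    state = opt.init(params)
    # Two steps with positive gradients build positive momentum in both coords.
    for _ in range(2):
        updates, state = opt.update(jnp.array([1.0, 1.0]), state, params)
        params = optax.apply_updates(params, updates)
    # Flip the sign of the second coordinate's gradient; its momentum is still
    # positive, so the cautious mask zeroes that coordinate's update.
    updates, state = opt.update(jnp.array([1.0, -1.0]), state, params)
    assert bool(updates[1] == 0.0) and bool(updates[0] != 0.0)
    return updates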
class ScaleByLionState(NamedTuple):
"""State for the Lion algorithm."""
count: chex.Array # shape=(), dtype=jnp.int32.
mu: base.Updates
def scale_by_clion(
b1: float = 0.9,
b2: float = 0.99,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Rescale updates according to the Lion algorithm.
See :func:`optax.lion` for more details.
Args:
b1: Rate for combining the momentum and the current grad.
b2: Decay rate for the exponentially weighted average of grads.
mu_dtype: Optional `dtype` to be used for the momentum; if `None` then the
`dtype` is inferred from `params` and `updates`.
Returns:
A :class:`optax.GradientTransformation` object.
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)
def init_fn(params):
mu = optax.tree.zeros_like(params, dtype=mu_dtype) # moment
return ScaleByLionState(count=jnp.zeros([], jnp.int32), mu=mu)
def update_fn(updates, state, params=None):
del params
def calc_update(g, m):
u = (1.0 - b1) * g + b1 * m
return (u * g > 0) * jnp.sign(u)
updates_new = jax.tree.map(
calc_update, updates, state.mu
)
mu = optax.tree.update_moment(updates, state.mu, b2, 1)
mu = optax.tree.cast(mu, mu_dtype)
count_inc = numerics.safe_increment(state.count)
return updates_new, ScaleByLionState(count=count_inc, mu=mu)
return base.GradientTransformation(init_fn, update_fn)
def clion(
learning_rate: base.ScalarOrSchedule,
b1: float = 0.9,
b2: float = 0.99,
mu_dtype: Optional[Any] = None,
weight_decay: float = 1e-3,
mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformationExtraArgs:
r"""The Lion optimizer.
Lion is discovered by symbolic program search. Unlike most adaptive optimizers
such as AdamW, Lion only tracks momentum, making it more memory-efficient.
The update of Lion is produced through the sign operation, resulting in a
larger norm compared to updates produced by other optimizers such as SGD and
AdamW. A suitable learning rate for Lion is typically 3-10x smaller than that
for AdamW, the weight decay for Lion should be in turn 3-10x larger than that
for AdamW to maintain a similar strength (lr * wd).
Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`,
represent the arguments ``b1`` and ``b2`` respectively. The learning rate is
indexed by :math:`t` since the learning rate may also be provided by a
schedule function. Let :math:`\lambda` be the weight decay and
:math:`\theta_t` the parameter vector at time :math:`t`.
The ``init`` function of this optimizer initializes an internal state
:math:`S_0 := (m_0) = (0)`, representing the initial estimate for the
first moment. In practice these values are stored as pytrees
containing all zeros, with the same shape as the model updates.
At step :math:`t`, the ``update`` function of this optimizer takes as
arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t`
and the parameters :math:`\theta_t` and computes updates :math:`u_t` and
new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have,
.. math::
\begin{align*}
c_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
u_t &\leftarrow -\alpha_t \cdot \left( sign \left( c_t \right) +
\lambda \theta_{t} \right)\\
m_t &\leftarrow \beta_2 \cdot m_{t-1} + (1-\beta_2) \cdot g_t \\
S_t &\leftarrow (m_t).
\end{align*}
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Rate to combine the momentum and the current gradient.
b2: Exponential decay rate to track the momentum of past gradients.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
weight_decay: Strength of the weight decay regularization. Note that this
weight decay is multiplied with the learning rate. This is consistent with
other frameworks such as PyTorch, but different from (Loshchilov et al,
2019) where the weight decay is only multiplied with the "schedule
multiplier", but not the base learning rate.
mask: A tree with same structure as (or a prefix of) the params PyTree, or a
Callable that returns such a pytree given the params/updates. The leaves
should be booleans, `True` for leaves/subtrees you want to apply the
weight decay to, and `False` for those you want to skip. Note that the
Lion gradient transformation is applied to all parameters.
Returns:
The corresponding :class:`optax.GradientTransformationExtraArgs`.
Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = clion(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References:
Chen et al, `Symbolic Discovery of Optimization Algorithms
<https://arxiv.org/abs/2302.06675>`_, 2023
`Cautious Optimizers: Improving Training with One Line of Code`, 2024
"""
return combine.chain(
scale_by_clion(b1=b1, b2=b2, mu_dtype=mu_dtype),
transform.add_decayed_weights(weight_decay, mask),
transform.scale_by_learning_rate(learning_rate),
)
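# --- Editor's sketch (not part of the original gist) -------------------------
# A hypothetical illustration of `clion`: with zero weight decay, every
# coordinate of the final update is either +/- the learning rate (a masked
# sign step) or exactly zero when the interpolated momentum disagrees in sign
# with the current gradient. The helper name and the numbers are assumptions.
def _demo_clion_sign_or_zero():
    params = jnp.array([1.0, -2.0, 3.0])
    opt = clion(learning_rate=0.01, weight_decay=0.0)
    state = opt.init(params)
    grad = jnp.array([0.5, -0.5, 0.0])
    updates, state = opt.update(grad, state, params)
    # Each coordinate is -lr * sign(c_t) when c_t and g_t agree, else zero.
    assert bool(jnp.all((jnp.abs(updates) == 0.01) | (updates == 0.0)))
    return updates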
def amsgrad(
learning_rate: base.ScalarOrSchedule,
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[Any] = None,
) -> base.GradientTransformationExtraArgs:
"""The AMSGrad optimizer.
The original Adam can fail to converge to the optimal solution in some cases.
AMSGrad guarantees convergence by using a long-term memory of past gradients.
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Exponential decay rate to track the first moment of past gradients.
b2: Exponential decay rate to track the second moment of past gradients.
eps: A small constant applied to denominator outside of the square root (as
in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
Returns:
The corresponding :class:`optax.GradientTransformationExtraArgs`.
Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = optax.amsgrad(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References:
Reddi et al, `On the Convergence of Adam and Beyond
<https://openreview.net/forum?id=ryQu7f-RZ>`_, 2018
"""
return combine.chain(
transform.scale_by_amsgrad(
b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype
),
transform.scale_by_learning_rate(learning_rate),
)
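# --- Editor's smoke test (not part of the original gist) ---------------------
# A minimal sketch mirroring the docstring examples: run each optimizer
# defined above for a few steps on a simple quadratic and check that the
# objective decreases. Exact loss values are intentionally not asserted.
if __name__ == "__main__":
    def f(x):
        return jnp.sum(x ** 2)

    for name, solver in [
        ("grams", grams(learning_rate=0.003)),
        ("cadamw", cadamw(learning_rate=0.003)),
        ("clion", clion(learning_rate=0.003)),
        ("amsgrad", amsgrad(learning_rate=0.003)),
    ]:
        params = jnp.array([1.0, 2.0, 3.0])
        opt_state = solver.init(params)
        initial = f(params)
        for _ in range(5):
            grad = jax.grad(f)(params)
            updates, opt_state = solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
        print(f"{name}: objective {float(initial):.4f} -> {float(f(params)):.4f}")
        assert f(params) < initial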
@kstoneriv3 (author):
TODO: modify the docstrings, which are copied and pasted from the optax code of AdamW and Lion.