@odiak
Created June 14, 2019 08:21
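import numpy as np
from dataclasses import dataclass


@dataclass(frozen=True)
class Adam:
    """Adam optimizer; the moment arrays m and v are owned by the caller."""

    learning_rate: float = 1.0  # kept from the original; Adam is often run with ~1e-3
    eps: float = 1e-8
    rho1: float = 0.9    # decay rate of the first-moment (mean) estimate
    rho2: float = 0.999  # decay rate of the second-moment (uncentered variance) estimate

    def calc_update(
        self, grad: np.ndarray, m: np.ndarray, v: np.ndarray, t: int = 1
    ) -> np.ndarray:
        """Return the parameter update for one step.

        m and v are updated in place; t is the 1-based step count.
        """
        t = max(t, 1)
        # Exponential moving averages of the gradient and the squared gradient.
        m[:] = m * self.rho1 + grad * (1 - self.rho1)
        v[:] = v * self.rho2 + np.square(grad) * (1 - self.rho2)
        # Bias correction: each moment is divided by its own decay factor
        # (v must use rho2, not rho1).
        m_ = m / (1 - self.rho1 ** t)
        v_ = v / (1 - self.rho2 ** t)
        # eps is added inside the square root here; the Adam paper adds it outside.
        update = m_ / np.sqrt(v_ + self.eps)
        return -self.learning_rate * update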
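The class leaves the optimizer state (m, v, and the step counter t) to the caller. The sketch below shows that bookkeeping by minimizing a simple quadratic. It is a hypothetical standalone illustration, not part of the original gist. It restates the same update as a pure function `adam_step` that returns new moment arrays instead of mutating them, and it places eps outside the square root as in the Adam paper. The names `adam_step` and `minimize_quadratic` and the choice of `lr=0.1` are assumptions made for the example.

```python
import numpy as np


def adam_step(grad, m, v, t, lr=0.1, rho1=0.9, rho2=0.999, eps=1e-8):
    """One Adam step without in-place mutation (hypothetical helper).

    Returns (update, new_m, new_v); t is the 1-based step count.
    """
    m = rho1 * m + (1 - rho1) * grad             # first-moment EMA
    v = rho2 * v + (1 - rho2) * np.square(grad)  # second-moment EMA
    m_hat = m / (1 - rho1 ** t)                  # bias correction with rho1
    v_hat = v / (1 - rho2 ** t)                  # bias correction with rho2
    update = -lr * m_hat / (np.sqrt(v_hat) + eps)
    return update, m, v


def minimize_quadratic(target, steps=500, lr=0.1):
    """Minimize f(x) = sum((x - target)**2) starting from x = 0."""
    x = np.zeros_like(target, dtype=float)
    m = np.zeros_like(x)  # caller-owned state, initialized to zero
    v = np.zeros_like(x)
    for t in range(1, steps + 1):
        grad = 2.0 * (x - target)  # gradient of the quadratic
        update, m, v = adam_step(grad, m, v, t, lr=lr)
        x = x + update
    return x


if __name__ == "__main__":
    print(minimize_quadratic(np.array([3.0, -1.5])))
```

Because of bias correction, the first step (t = 1) gives m_hat = grad and v_hat = grad², so each coordinate moves by about `lr` against the sign of its gradient, whatever the gradient's scale.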