Skip to content

Instantly share code, notes, and snippets.

@enochkan
Last active December 31, 2022 16:16
Show Gist options
  • Save enochkan/56af870bd19884f189639a0cb3381ff4 to your computer and use it in GitHub Desktop.
Save enochkan/56af870bd19884f189639a0cb3381ff4 to your computer and use it in GitHub Desktop.
Class definition of the Adam optimizer
import numpy as np
class AdamOptim:
    """Adam optimizer for one weight parameter and one bias parameter.

    The instance keeps exponentially decaying averages of the gradients
    (first moment, ``m_*``) and of the squared gradients (second moment,
    ``v_*``) separately for the weights (``dw``) and the biases (``db``).
    """

    def __init__(self, eta=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8):
        """Create the optimizer with zero-initialised moment estimates.

        Args:
            eta: Step size (learning rate) applied to each update.
            beta1: Decay rate of the first-moment (momentum) averages.
            beta2: Decay rate of the second-moment (RMS) averages.
            epsilon: Constant added to the denominator of the update so the
                division stays finite when the second moment is zero.
        """
        # Scalars are enough here: they broadcast against whatever dw/db
        # the first update() call supplies.
        self.m_dw, self.v_dw = 0, 0
        self.m_db, self.v_db = 0, 0
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.eta = eta

    def update(self, t, w, b, dw, db):
        """Apply one Adam step and return the updated ``(w, b)``.

        Args:
            t: 1-based index of the current update step. It is the exponent
                in the bias-correction terms ``1 - beta**t``, so the first
                call must use ``t = 1``.
            w: Current weights.
            b: Current biases.
            dw: Gradient with respect to ``w`` from the current minibatch.
            db: Gradient with respect to ``b`` from the current minibatch.

        Returns:
            Tuple ``(w, b)`` holding the updated weights and biases.

        Raises:
            ValueError: If ``t`` is smaller than 1; ``t = 0`` would make the
                bias-correction denominators zero.
        """
        if t < 1:
            raise ValueError("t must be a 1-based step count (t >= 1)")

        # First moment (momentum): decaying average of the gradients.
        self.m_dw = self.beta1 * self.m_dw + (1 - self.beta1) * dw
        self.m_db = self.beta1 * self.m_db + (1 - self.beta1) * db

        # Second moment (RMS): decaying average of the SQUARED gradients.
        # The bias term must be squared like the weight term; otherwise a
        # negative db drives v_db negative and the sqrt below yields NaN.
        self.v_dw = self.beta2 * self.v_dw + (1 - self.beta2) * (dw ** 2)
        self.v_db = self.beta2 * self.v_db + (1 - self.beta2) * (db ** 2)

        # Bias correction: the averages start at zero, so early estimates
        # are scaled up by 1 / (1 - beta**t).
        m_scale = 1 - self.beta1 ** t
        v_scale = 1 - self.beta2 ** t
        m_dw_corr = self.m_dw / m_scale
        m_db_corr = self.m_db / m_scale
        v_dw_corr = self.v_dw / v_scale
        v_db_corr = self.v_db / v_scale

        # Parameter update.
        w = w - self.eta * (m_dw_corr / (np.sqrt(v_dw_corr) + self.epsilon))
        b = b - self.eta * (m_db_corr / (np.sqrt(v_db_corr) + self.epsilon))
        return w, b
@Samuel-Bachorik
Copy link

Samuel-Bachorik commented Dec 31, 2022

Hi, what is "t" parameter in update function?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment