-
-
Save kirk86/8f0e6f007ee9c179d0f72cf52456afdc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def adam(loss, all_params, learning_rate=0.001, b1=0.9, b2=0.999, e=1e-8,
         gamma=1-1e-8):
    """
    Build Theano update pairs implementing the ADAM update rules.

    Follows the algorithm of [Kingma2014] (arXiv v4), including its decay of
    the first-moment coefficient ``b1_t = b1 * gamma**(t - 1)``. Default
    values are taken from [Kingma2014].

    The body refers to ``theano``, ``T`` (``theano.tensor``) and ``np``
    (``numpy``); those names must be available in the enclosing module.

    Parameters
    ----------
    loss : Theano scalar expression to minimise.
    all_params : sequence of Theano shared variables to optimise (each must
        support ``get_value()``; they are passed to ``theano.grad``).
    learning_rate : step size (alpha in the paper).
    b1, b2 : exponential decay rates of the first / second moment estimates.
    e : epsilon added to the denominator for numerical stability.
    gamma : per-step decay factor (lambda in the paper) applied to ``b1``.

    Returns
    -------
    list of ``(shared_variable, update_expression)`` pairs suitable for
    ``theano.function(..., updates=...)``. The list holds three pairs per
    parameter (first moment, second moment, parameter), followed by a single
    pair that increments the shared timestep counter.

    References:
        [Kingma2014] Kingma, Diederik, and Jimmy Ba.
        "Adam: A Method for Stochastic Optimization."
        arXiv preprint arXiv:1412.6980 (2014).
        http://arxiv.org/pdf/1412.6980v4.pdf
    """
    updates = []
    all_grads = theano.grad(loss, all_params)
    alpha = learning_rate
    # Timestep counter shared by all parameters. It starts at 1 and is
    # incremented once per call of the compiled function (see the end).
    t = theano.shared(np.asarray(1., dtype=theano.config.floatX))
    # (Decay the first moment running average coefficient)
    # NOTE(review): 1-1e-8 is not representable in float32 (it rounds to
    # 1.0), so if Theano casts gamma to float32 this decay is a no-op -
    # confirm if the decay matters.
    b1_t = b1*gamma**(t-1)

    for theta_previous, g in zip(all_params, all_grads):
        shape = theta_previous.get_value().shape
        # The moment buffers copy the parameter's broadcast pattern.
        # Otherwise the parameter update below would lose the parameter's
        # broadcastable dims, and Theano rejects type-changing updates.
        m_previous = theano.shared(np.zeros(shape, dtype=theano.config.floatX),
                                   broadcastable=theta_previous.broadcastable)
        v_previous = theano.shared(np.zeros(shape, dtype=theano.config.floatX),
                                   broadcastable=theta_previous.broadcastable)

        m = b1_t*m_previous + (1 - b1_t)*g  # (Update biased first moment estimate)
        v = b2*v_previous + (1 - b2)*g**2  # (Update biased second raw moment estimate)
        m_hat = m / (1-b1**t)  # (Compute bias-corrected first moment estimate)
        v_hat = v / (1-b2**t)  # (Compute bias-corrected second raw moment estimate)
        theta = theta_previous - (alpha * m_hat) / (T.sqrt(v_hat) + e)  # (Update parameters)

        updates.append((m_previous, m))
        updates.append((v_previous, v))
        updates.append((theta_previous, theta))

    # Exactly one increment per step, outside the loop: Theano refuses two
    # update expressions for the same shared variable.
    updates.append((t, t + 1.))
    return updates
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment