Skip to content

Instantly share code, notes, and snippets.

@aerinkim
Last active December 24, 2020 16:58
Show Gist options
  • Save aerinkim/dfe3da1000e67aced1c7d9279351cb88 to your computer and use it in GitHub Desktop.
Save aerinkim/dfe3da1000e67aced1c7d9279351cb88 to your computer and use it in GitHub Desktop.
Adam Implementation from scratch
import math

import torch
from torch.optim import Optimizer
class ADAMOptimizer(Optimizer):
    """Adam optimizer written from scratch on top of ``torch.optim.Optimizer``.

    For every parameter the optimizer keeps two exponential moving averages
    in ``self.state``: one of the gradients (``exp_avg``, the momentum term)
    and one of the squared gradients (``exp_avg_sq``, the RMSProp-style
    term). Each step divides the first by the square root of the second and
    scales the result by a bias-corrected learning rate.

    Args:
        params: Iterable of parameters, or of dicts defining parameter groups.
        lr: Learning rate.
        betas: Pair ``(b1, b2)`` of decay rates for the gradient average and
            the squared-gradient average respectively.
        eps: Constant added to the denominator for numerical stability.
        weight_decay: L2 penalty coefficient; ``0`` disables it.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            b1, b2 = group['betas']
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                # Parameters that took no part in the backward pass have no
                # gradient; leave them (and their state) untouched.
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                # Lazy state initialization on the first step for this parameter.
                if len(state) == 0:
                    state['step'] = 0
                    # Momentum (exponential moving average of gradients).
                    state['exp_avg'] = torch.zeros_like(p)
                    # RMSProp component (exponential moving average of
                    # squared gradients); feeds the denominator.
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                step = state['step']

                # L2 penalty: fold weight_decay * p into the gradient. The
                # out-of-place add keeps p.grad itself unmodified.
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                # Update the moving averages IN PLACE so the tensors stored in
                # self.state actually accumulate across steps.
                exp_avg.mul_(b1).add_(grad, alpha=1 - b1)
                exp_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)

                # eps is added to the square root of the uncorrected second
                # moment; the bias corrections are folded into the step size.
                denom = exp_avg_sq.sqrt().add_(eps)

                bias_correction1 = 1 / (1 - b1 ** step)
                bias_correction2 = 1 / (1 - b2 ** step)
                adapted_learning_rate = lr * bias_correction1 / math.sqrt(bias_correction2)

                # p <- p - adapted_learning_rate * exp_avg / denom
                p.addcdiv_(exp_avg, denom, value=-adapted_learning_rate)

                # Periodic debug dump of the group and the current weights.
                if step % 10000 == 0:
                    print("group:", group)
                    print("p: ", p)
                    print("p.data: ", p.data)  # W = p.data

        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment