Skip to content

Instantly share code, notes, and snippets.

@finbarrtimbers
Last active March 9, 2025 18:40
Show Gist options
  • Save finbarrtimbers/98a03be83a8953a461f8b1d8716feebc to your computer and use it in GitHub Desktop.
Save finbarrtimbers/98a03be83a8953a461f8b1d8716feebc to your computer and use it in GitHub Desktop.
Adam
class SimpleAdam(torch.optim.Optimizer):
    """Minimal Adam optimizer.

    Keeps, for every parameter, an exponential moving average of the gradient
    (``'first_moment'``) and of the squared gradient (``'second_moment'``) in
    ``self.state[p]``. Both averages are bias-corrected with a single step
    counter ``self.t`` that is shared by all parameters.

    Args:
        params: Iterable of parameters or param-group dicts, forwarded to
            ``torch.optim.Optimizer``.
        lr: Learning rate; stored per param group under ``'lr'``.
        betas: ``(beta1, beta2)`` decay rates for the first and second
            moment averages. Shared by all param groups.
        eps: Constant added to the denominator of the update.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        # The base class creates ``self.state`` and ``self.param_groups``;
        # we only populate the state it already owns.
        super().__init__(params, defaults={'lr': lr})
        self.t = 0
        self.betas = betas
        self.eps = eps
        for group in self.param_groups:
            for p in group['params']:
                self.state[p] = self._new_state(p)

    @staticmethod
    def _new_state(p):
        """Return zero-initialised moment buffers shaped like ``p``."""
        return {
            'first_moment': torch.zeros_like(p.data),
            'second_moment': torch.zeros_like(p.data),
        }

    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is called once, before the update.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = closure() if closure is not None else None

        self.t += 1
        beta1, beta2 = self.betas
        # Bias corrections depend only on the shared step counter, so they
        # are computed once per step rather than once per parameter.
        bias_correction1 = 1 - beta1 ** self.t
        bias_correction2 = 1 - beta2 ** self.t

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                # Parameters that received no gradient (frozen, unused in the
                # forward pass, or cleared by zero_grad) are left untouched.
                if p.grad is None:
                    continue
                grad = p.grad.data

                # Parameters added after construction (add_param_group) have
                # no buffers yet; create them on first use.
                state = self.state.get(p)
                if not state:
                    state = self.state[p] = self._new_state(p)

                first_moment = (
                    beta1 * state['first_moment'] + (1 - beta1) * grad
                )
                second_moment = (
                    beta2 * state['second_moment'] + (1 - beta2) * grad ** 2
                )
                state['first_moment'] = first_moment
                state['second_moment'] = second_moment

                first_moment_corrected = first_moment / bias_correction1
                second_moment_corrected = second_moment / bias_correction2
                p.data -= lr * first_moment_corrected / (
                    second_moment_corrected.sqrt() + self.eps
                )
        return loss
class SimpleAdamW(torch.optim.Optimizer):
    """Minimal AdamW optimizer (Adam with decoupled weight decay).

    Keeps, for every parameter, an exponential moving average of the gradient
    (``'first_moment'``) and of the squared gradient (``'second_moment'``) in
    ``self.state[p]``. Both averages are bias-corrected with a single step
    counter ``self.t`` that is shared by all parameters. Weight decay is
    applied directly to the parameter, scaled by the learning rate, and is
    not mixed into the gradient.

    Args:
        params: Iterable of parameters or param-group dicts, forwarded to
            ``torch.optim.Optimizer``.
        lr: Learning rate; stored per param group under ``'lr'``.
        betas: ``(beta1, beta2)`` decay rates for the first and second
            moment averages. Shared by all param groups.
        eps: Constant added to the denominator of the update.
        weight_decay: Decoupled weight-decay coefficient.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: float = 1e-5):
        # The base class creates ``self.state`` and ``self.param_groups``;
        # we only populate the state it already owns.
        super().__init__(params, defaults={'lr': lr})
        self.t = 0
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        for group in self.param_groups:
            for p in group['params']:
                self.state[p] = self._new_state(p)

    @staticmethod
    def _new_state(p):
        """Return zero-initialised moment buffers shaped like ``p``."""
        return {
            'first_moment': torch.zeros_like(p.data),
            'second_moment': torch.zeros_like(p.data),
        }

    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is called once, before the update.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = closure() if closure is not None else None

        self.t += 1
        beta1, beta2 = self.betas
        # Bias corrections depend only on the shared step counter, so they
        # are computed once per step rather than once per parameter.
        bias_correction1 = 1 - beta1 ** self.t
        bias_correction2 = 1 - beta2 ** self.t

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                # Parameters that received no gradient (frozen, unused in the
                # forward pass, or cleared by zero_grad) are left untouched:
                # neither decayed nor updated.
                if p.grad is None:
                    continue
                grad = p.grad.data

                # Parameters added after construction (add_param_group) have
                # no buffers yet; create them on first use.
                state = self.state.get(p)
                if not state:
                    state = self.state[p] = self._new_state(p)

                first_moment = (
                    beta1 * state['first_moment'] + (1 - beta1) * grad
                )
                second_moment = (
                    beta2 * state['second_moment'] + (1 - beta2) * grad ** 2
                )
                state['first_moment'] = first_moment
                state['second_moment'] = second_moment

                first_moment_corrected = first_moment / bias_correction1
                second_moment_corrected = second_moment / bias_correction2

                # Decoupled weight decay: shrink the weights first, then
                # apply the Adam update computed from the raw gradient.
                p.data -= lr * self.weight_decay * p.data
                p.data -= lr * first_moment_corrected / (
                    second_moment_corrected.sqrt() + self.eps
                )
        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment