import math

from torch.optim.optimizer import Optimizer


# This version of Adam keeps an fp32 copy of the parameters and
# does all of the parameter updates in fp32, while still doing the
# forwards and backwards passes using fp16 (i.e. fp16 copies of the
# parameters and fp16 activations).
#
# Note that this calls .float().cuda() on the params, so it moves
# them to GPU 0 -- if you're using a different GPU or want to do
# multi-GPU you may need to deal with this.
class Adam16(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        params = list(params)
        super(Adam16, self).__init__(params, defaults)
        # Keep an fp32 master copy of every parameter on GPU 0.
        self.fp32_param_groups = [p.data.float().cuda() for p in params]
        if not isinstance(self.fp32_param_groups[0], dict):
            self.fp32_param_groups = [{'params': self.fp32_param_groups}]

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group, fp32_group in zip(self.param_groups, self.fp32_param_groups):
            for p, fp32_p in zip(group['params'], fp32_group['params']):
                if p.grad is None:
                    continue

                # Upcast the fp16 gradient to fp32 for the update.
                grad = p.grad.data.float()
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = grad.new().resize_as_(grad).zero_()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Apply weight decay against the fp32 master weights.
                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], fp32_p)

                # Decay the first and second moment running average coefficients.
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Update the fp32 master weights, then copy them back into
                # the fp16 model parameters.
                fp32_p.addcdiv_(-step_size, exp_avg, denom)
                p.data = fp32_p.half()

        return loss
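
Below is a minimal usage sketch, not part of the original gist: the model and its activations run in fp16 while Adam16 keeps fp32 master weights, updates them, and writes the halved result back into the fp16 parameters each step. The layer sizes, batch size, and variable names are illustrative placeholders, and it assumes a PyTorch 1.x release that still accepts the gist's old-style add_/addcmul_/addcdiv_ call signatures.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy setup: an fp16 model on GPU 0, matching the
# .float().cuda() assumption inside Adam16.
model = nn.Linear(512, 10).cuda().half()
optimizer = Adam16(model.parameters(), lr=1e-3)

x = torch.randn(32, 512).cuda().half()          # fp16 inputs and activations
target = torch.randint(0, 10, (32,)).cuda()

optimizer.zero_grad()
logits = model(x)                               # fp16 forward pass
loss = F.cross_entropy(logits.float(), target)  # compute the loss in fp32
loss.backward()                                 # fp16 gradients
optimizer.step()                                # fp32 update, copied back as fp16

The fp32 master copy is what lets small Adam steps survive: at fp16 precision a weight of 1.0 plus an update of 1e-4 rounds back to 1.0 (fp16 has roughly 10 bits of mantissa), whereas the same update applied to the fp32 copy is retained, and only the final p.data = fp32_p.half() conversion is lossy.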