Created
March 16, 2018 11:32
-
-
Save kashif/ecbe62c34a026b7d10f3312d0300a29d to your computer and use it in GitHub Desktop.
AccSGD optimizer for Keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AccSGD(Optimizer):
    """AccSGD (accelerated stochastic gradient descent) optimizer.

    Every parameter ``p`` is paired with an auxiliary averaged iterate ``m``
    that starts out as a copy of ``p``.  With the derived coefficients

        large_lr = lr * kappa / smallConst
        beta     = smallConst ** 2 * xi / kappa
        alpha    = 1 - beta
        zeta     = smallConst / (smallConst + beta)

    one step with gradient ``g`` performs

        m_new = alpha * m + beta * (p - large_lr * g)
        p_new = zeta * (p - lr * g) + (beta / (smallConst + beta)) * m_new

    # Arguments
        lr (float): learning rate (default: 0.1)
        kappa (float, optional): ratio of long to short step (default: 1000)
        xi (float, optional): statistical advantage parameter (default: 10)
        smallConst (float, optional): any value <=1 (default: 0.7)

    # References
        - [Accelerating Stochastic Gradient Descent](https://arxiv.org/abs/1704.08227)
    """

    def __init__(self, lr=0.1, kappa=1000.0, xi=10.0, smallConst=0.7, **kwargs):
        super(AccSGD, self).__init__(**kwargs)
        # Hyper-parameters are stored as backend variables so they can be
        # read back with K.get_value (see get_config).
        with K.name_scope(self.__class__.__name__):
            self.lr = K.variable(lr, name='lr')
            self.kappa = K.variable(kappa, name='kappa')
            self.xi = K.variable(xi, name='xi')
            self.smallConst = K.variable(smallConst, name='smallConst')

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        """Build the update ops for one AccSGD step.

        # Arguments
            loss: loss tensor to differentiate.
            params: list of parameter variables to optimise.

        # Returns
            List of update ops: for every parameter, one op updating its
            averaged iterate followed by one op updating the parameter.
        """
        grads = self.get_gradients(loss, params)

        # Rebuild the list on every call; appending to a previous call's
        # list would hand back stale update ops in addition to the new ones,
        # so each parameter would be assigned more than once per step.
        self.updates = []

        # Step-invariant coefficients, computed once for all parameters.
        large_lr = (self.lr * self.kappa) / self.smallConst
        beta = (self.smallConst * self.smallConst * self.xi) / self.kappa
        alpha = 1.0 - beta
        zeta = self.smallConst / (self.smallConst + beta)
        avg_weight = beta / (self.smallConst + beta)  # equals 1 - zeta

        # One averaged iterate per parameter, initialised to the parameter's
        # current value.
        ms = [K.variable(K.identity(p), dtype=K.dtype(p)) for p in params]
        # Register the optimizer state, following the convention of Keras'
        # built-in optimizers (which assign self.weights in get_updates).
        self.weights = ms

        for p, g, m in zip(params, grads, ms):
            # Long step, blended into the running average.
            m_t = alpha * m + beta * (p - large_lr * g)
            # Short step, mixed with the freshly updated average.
            p_t = zeta * (p - self.lr * g) + avg_weight * m_t

            self.updates.append(K.update(m, m_t))

            new_p = p_t
            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        """Return the hyper-parameters merged into the base class config."""
        config = {'lr': float(K.get_value(self.lr)),
                  'kappa': float(K.get_value(self.kappa)),
                  'smallConst': float(K.get_value(self.smallConst)),
                  'xi': float(K.get_value(self.xi))}
        base_config = super(AccSGD, self).get_config()
        # Entries from `config` win over same-named base entries.
        return dict(list(base_config.items()) + list(config.items()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment