from keras.optimizers import Optimizer
from keras import backend as K
from keras.legacy import interfaces


class MultiSGD(Optimizer):
    """Stochastic gradient descent optimizer with support for per-weight
    learning rate multipliers for kernels and biases, as suggested in:
    https://github.com/fchollet/keras/issues/5920

    Includes support for momentum, learning rate decay, and Nesterov
    momentum.

    # Arguments
        lr: float >= 0. Learning rate.
        momentum: float >= 0. Parameter that accelerates SGD in the
            relevant direction and dampens oscillations.
        decay: float >= 0. Learning rate decay over each update.
        nesterov: boolean. Whether to apply Nesterov momentum.
        lr_mult: dict mapping weight names (e.g. `layer.weights[0].name`)
            to float multipliers applied to the learning rate for those
            weights.
    """

    def __init__(self, lr=0.01, momentum=0., decay=0.,
                 nesterov=False, lr_mult=None, **kwargs):
        super(MultiSGD, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.momentum = K.variable(momentum, name='momentum')
            self.decay = K.variable(decay, name='decay')
        self.initial_decay = decay
        self.nesterov = nesterov
        # Default to an empty dict so the membership test in get_updates
        # is safe when no multipliers are given.
        self.lr_mult = lr_mult if lr_mult is not None else {}

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))
        # momentum
        shapes = [K.int_shape(p) for p in params]
        moments = [K.zeros(shape) for shape in shapes]
        self.weights = [self.iterations] + moments
        for p, g, m in zip(params, grads, moments):
            # Scale the learning rate for weights that have a multiplier.
            if p.name in self.lr_mult:
                multiplied_lr = lr * self.lr_mult[p.name]
            else:
                multiplied_lr = lr

            v = self.momentum * m - multiplied_lr * g  # velocity
            self.updates.append(K.update(m, v))

            if self.nesterov:
                new_p = p + self.momentum * v - multiplied_lr * g
            else:
                new_p = p + v

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        # Note: lr_mult is not serialized here, so it must be supplied
        # again when the optimizer is reconstructed from its config.
        config = {'lr': float(K.get_value(self.lr)),
                  'momentum': float(K.get_value(self.momentum)),
                  'decay': float(K.get_value(self.decay)),
                  'nesterov': self.nesterov}
        base_config = super(MultiSGD, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
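For reference, the update that get_updates applies to each parameter is standard momentum SGD with a per-weight step size. Writing \mu for the momentum, g for the gradient, and \eta_p = lr * lr_mult[p.name] for the (decayed) learning rate after the multiplier, the rule implemented above is:

    v_t = \mu\, v_{t-1} - \eta_p\, g_t
    p_t = p_{t-1} + \mu\, v_t - \eta_p\, g_t  \quad \text{(Nesterov)}
    p_t = p_{t-1} + v_t                       \quad \text{(otherwise)}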
# Create a learning rate multiplier dictionary.
# Example: define lr multipliers for convolutional layers whose names
# start with the 'conv2d' prefix. `model` is an already-built Keras model.
import re

from keras.layers import Conv2D

lr_mult = dict()
for layer in model.layers:
    if isinstance(layer, Conv2D) and re.match("conv2d.*", layer.name):
        # For a Conv2D layer, weights[0] is the kernel and weights[1]
        # is the bias.
        kernel_name = layer.weights[0].name
        bias_name = layer.weights[1].name
        # Kernels train at the base learning rate, biases at twice it.
        lr_mult[kernel_name] = 1
        lr_mult[bias_name] = 2

multisgd = MultiSGD(lr=1e-2, momentum=0.9, decay=0.0, nesterov=False,
                    lr_mult=lr_mult)
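From here the optimizer is passed to model.compile like any built-in one; the loss below is only a placeholder, not something specified in the snippet:

# Minimal sketch: 'categorical_crossentropy' is an example loss, not
# part of the original gist.
model.compile(optimizer=multisgd, loss='categorical_crossentropy')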