# Gist by @jimmy15923, created January 9, 2019
import keras.layers as KL
import keras.engine as KE
import keras.models as KM
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
class BatchNorm(KL.BatchNormalization):
"""Extends the Keras BatchNormalization class to allow a central place
to make changes if needed.
Batch normalization has a negative effect on training if batches are small,
so this layer is often frozen (via a setting in the Config class) and functions
as a linear layer.
"""
def call(self, inputs, training=None):
"""
Note about training values:
None: Train BN layers. This is the normal mode
False: Freeze BN layers. Good when batch size is small
True: (don't use). Set layer in training mode even when inferencing
"""
return super(self.__class__, self).call(inputs, training=training)
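# --- Usage sketch (not part of the original gist): a minimal, hypothetical example of
# how the `training` flag is used to freeze this layer. Assumes standalone Keras with a
# TensorFlow backend; layer sizes are arbitrary.
def _demo_frozen_batchnorm():
    x = KL.Input(shape=(64, 64, 3))
    y = KL.Conv2D(8, 3, padding='same')(x)
    # training=False forces the layer to use its moving statistics even while the
    # surrounding model is training, which is the recommended mode for small batches.
    y = BatchNorm()(y, training=False)
    return KM.Model(x, y)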
class GroupNorm(Layer):
"""Group normalization layer
Group Normalization divides the channels into groups and computes the mean and
variance for normalization within each group. GN's computation is independent of
batch size, and its accuracy is stable across a wide range of batch sizes.
# Arguments
groups: Integer, the number of groups for Group Normalization.
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `GroupNorm`.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Group Normalization](https://arxiv.org/abs/1803.08494)
"""
def __init__(self,
groups=32,
axis=-1,
epsilon=1e-5,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(GroupNorm, self).__init__(**kwargs)
self.supports_masking = True
self.groups = groups
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError('Axis ' + str(self.axis) + ' of '
'input tensor should have a defined dimension '
'but the layer received an input with shape ' +
str(input_shape) + '.')
if dim < self.groups:
raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
'more than the number of channels (' +
str(dim) + ').')
if dim % self.groups != 0:
raise ValueError('Number of channels (' + str(dim) + ') must be a '
'multiple of the number of groups (' +
str(self.groups) + ').')
self.input_spec = InputSpec(ndim=len(input_shape),
axes={self.axis: dim})
shape = (dim,)
if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
def call(self, inputs, **kwargs):
input_shape = K.int_shape(inputs)
tensor_input_shape = K.shape(inputs)
# Prepare broadcasting shape.
reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
broadcast_shape.insert(1, self.groups)
reshape_group_shape = K.shape(inputs)
group_axes = [reshape_group_shape[i] for i in range(len(input_shape))]
group_axes[self.axis] = input_shape[self.axis] // self.groups
group_axes.insert(1, self.groups)
# reshape inputs to new group shape
group_shape = [group_axes[0], self.groups] + group_axes[2:]
group_shape = K.stack(group_shape)
inputs = K.reshape(inputs, group_shape)
group_reduction_axes = list(range(len(group_axes)))
group_reduction_axes = group_reduction_axes[2:]
mean = K.mean(inputs, axis=group_reduction_axes, keepdims=True)
variance = K.var(inputs, axis=group_reduction_axes, keepdims=True)
inputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))
# prepare broadcast shape
inputs = K.reshape(inputs, group_shape)
outputs = inputs
# In this case we must explicitly broadcast all parameters.
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
outputs = outputs * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
outputs = outputs + broadcast_beta
outputs = K.reshape(outputs, tensor_input_shape)
return outputs
def get_config(self):
config = {
'groups': self.groups,
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(GroupNorm, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def compute_output_shape(self, input_shape):
return input_shape
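# --- Usage sketch (not part of the original gist): a hypothetical example of GroupNorm
# after a convolution. The channel count (16 here) must be divisible by `groups`,
# otherwise build() raises a ValueError.
def _demo_groupnorm():
    x = KL.Input(shape=(32, 32, 3))
    y = KL.Conv2D(16, 3, padding='same')(x)
    y = GroupNorm(groups=4, axis=-1)(y)  # 16 channels / 4 groups = 4 channels per group
    return KM.Model(x, y)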
class SwitchNorm(Layer):
"""Switchable Normalization layer
Switchable Normalization computes Instance Normalization, Layer Normalization and
Batch Normalization statistics, then combines them with learned weights, so each
layer can pick its own mixture of the three normalization schemes.
Only the moving average variant from the paper is supported, since the `batch average`
scheme requires dynamic graph execution to compute the mean and variance over several
batches at runtime.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `SwitchNorm`.
momentum: Momentum for the moving mean and the moving variance. The original
implementation suggests a default momentum of `0.997`, however it is highly
unstable and training can fail after a few epochs. To stabilise training, use
lower values of momentum such as `0.99` or `0.98`.
epsilon: Small float added to variance to avoid dividing by zero.
final_gamma: Bool value to determine if this layer is the final
normalization layer for the residual block. Overrides the initialization
of the scaling weights to be `zeros`. Only used for Residual Networks,
to make the forward/backward signal initially propagated through an
identity shortcut.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
mean_weights_initializer: Initializer for the mean weights.
variance_weights_initializer: Initializer for the variance weights.
moving_mean_initializer: Initializer for the moving mean.
moving_variance_initializer: Initializer for the moving variance.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
mean_weights_regularizer: Optional regularizer for the mean weights.
variance_weights_regularizer: Optional regularizer for the variance weights.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
mean_weights_constraints: Optional constraint for the mean weights.
variance_weights_constraints: Optional constraint for the variance weights.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Differentiable Learning-to-Normalize via Switchable Normalization](https://arxiv.org/abs/1806.10779)
"""
def __init__(self,
axis=-1,
momentum=0.99,
epsilon=1e-3,
final_gamma=False,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
mean_weights_initializer='ones',
variance_weights_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
mean_weights_regularizer=None,
variance_weights_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
mean_weights_constraints=None,
variance_weights_constraints=None,
**kwargs):
super(SwitchNorm, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.momentum = momentum
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
if final_gamma:
self.gamma_initializer = initializers.get('zeros')
else:
self.gamma_initializer = initializers.get(gamma_initializer)
self.mean_weights_initializer = initializers.get(mean_weights_initializer)
self.variance_weights_initializer = initializers.get(variance_weights_initializer)
self.moving_mean_initializer = initializers.get(moving_mean_initializer)
self.moving_variance_initializer = initializers.get(moving_variance_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.mean_weights_regularizer = regularizers.get(mean_weights_regularizer)
self.variance_weights_regularizer = regularizers.get(variance_weights_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
self.mean_weights_constraints = constraints.get(mean_weights_constraints)
self.variance_weights_constraints = constraints.get(variance_weights_constraints)
def build(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError('Axis ' + str(self.axis) + ' of '
'input tensor should have a defined dimension '
'but the layer received an input with shape ' +
str(input_shape) + '.')
self.input_spec = InputSpec(ndim=len(input_shape),
axes={self.axis: dim})
shape = (dim,)
if self.scale:
self.gamma = self.add_weight(
shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(
shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.moving_mean = self.add_weight(
shape=shape,
name='moving_mean',
initializer=self.moving_mean_initializer,
trainable=False)
self.moving_variance = self.add_weight(
shape=shape,
name='moving_variance',
initializer=self.moving_variance_initializer,
trainable=False)
self.mean_weights = self.add_weight(
shape=(3,),
name='mean_weights',
initializer=self.mean_weights_initializer,
regularizer=self.mean_weights_regularizer,
constraint=self.mean_weights_constraints)
self.variance_weights = self.add_weight(
shape=(3,),
name='variance_weights',
initializer=self.variance_weights_initializer,
regularizer=self.variance_weights_regularizer,
constraint=self.variance_weights_constraints)
self.built = True
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
# Prepare broadcasting shape.
reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
if self.axis != 0:
del reduction_axes[0]
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis]
mean_instance = K.mean(inputs, reduction_axes, keepdims=True)
variance_instance = K.var(inputs, reduction_axes, keepdims=True)
mean_layer = K.mean(mean_instance, self.axis, keepdims=True)
temp = variance_instance + K.square(mean_instance)
variance_layer = K.mean(temp, self.axis, keepdims=True) - K.square(mean_layer)
def training_phase():
mean_batch = K.mean(mean_instance, axis=0, keepdims=True)
variance_batch = K.mean(temp, axis=0, keepdims=True) - K.square(mean_batch)
mean_batch_reshaped = K.flatten(mean_batch)
variance_batch_reshaped = K.flatten(variance_batch)
if K.backend() != 'cntk':
sample_size = K.prod([K.shape(inputs)[axis]
for axis in reduction_axes])
sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
# sample variance - unbiased estimator of population variance
variance_batch_reshaped *= sample_size / (sample_size - (1.0 + self.epsilon))
self.add_update([K.moving_average_update(self.moving_mean,
mean_batch_reshaped,
self.momentum),
K.moving_average_update(self.moving_variance,
variance_batch_reshaped,
self.momentum)],
inputs)
return normalize_func(mean_batch, variance_batch)
def inference_phase():
mean_batch = self.moving_mean
variance_batch = self.moving_variance
return normalize_func(mean_batch, variance_batch)
def normalize_func(mean_batch, variance_batch):
mean_batch = K.reshape(mean_batch, broadcast_shape)
variance_batch = K.reshape(variance_batch, broadcast_shape)
mean_weights = K.softmax(self.mean_weights, axis=0)
variance_weights = K.softmax(self.variance_weights, axis=0)
mean = (mean_weights[0] * mean_instance +
mean_weights[1] * mean_layer +
mean_weights[2] * mean_batch)
variance = (variance_weights[0] * variance_instance +
variance_weights[1] * variance_layer +
variance_weights[2] * variance_batch)
outputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
outputs = outputs * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
outputs = outputs + broadcast_beta
return outputs
if training in {0, False}:
return inference_phase()
return K.in_train_phase(training_phase,
inference_phase,
training=training)
def get_config(self):
config = {
'axis': self.axis,
'epsilon': self.epsilon,
'momentum': self.momentum,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'mean_weights_initializer': initializers.serialize(self.mean_weights_initializer),
'variance_weights_initializer': initializers.serialize(self.variance_weights_initializer),
'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'mean_weights_regularizer': regularizers.serialize(self.mean_weights_regularizer),
'variance_weights_regularizer': regularizers.serialize(self.variance_weights_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint),
'mean_weights_constraints': constraints.serialize(self.mean_weights_constraints),
'variance_weights_constraints': constraints.serialize(self.variance_weights_constraints),
}
base_config = super(SwitchNorm, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def compute_output_shape(self, input_shape):
return input_shape
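# --- Usage sketch (not part of the original gist): a hypothetical example of SwitchNorm
# in a small model. A momentum below the paper's 0.997 is used, per the docstring note
# that the higher value can make training unstable.
def _demo_switchnorm():
    x = KL.Input(shape=(32, 32, 3))
    y = KL.Conv2D(16, 3, padding='same')(x)
    y = SwitchNorm(axis=-1, momentum=0.98)(y)
    return KM.Model(x, y)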
get_custom_objects().update({'GroupNorm': GroupNorm})
get_custom_objects().update({'SwitchNorm': SwitchNorm})
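# --- Serialization sketch (not part of the original gist): because the layers are
# registered in get_custom_objects(), a saved model that uses them can be reloaded
# without passing custom_objects explicitly. The file path is hypothetical.
def _demo_save_and_reload(model, path='model_with_custom_norm.h5'):
    model.save(path)
    return KM.load_model(path)  # 'GroupNorm' / 'SwitchNorm' resolved via the registry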