""" MY Batchnormalization layer | |
ref & thanks. | |
https://github.com/tensorflow/tensorflow/issues/18222 | |
https://github.com/jkyl/biggan-deep/blob/master/src/custom_layers/batch_normalization.py | |
https://github.com/tensorflow/community/blob/master/rfcs/20181016-replicator.md#global-batch-normalization | |
https://github.com/Apm5/tensorflow_2.0_tutorial/blob/master/CNN/BatchNormalization.py | |
WARN: This layer cannot accept the variable fused=True, | |
We need select fused=False at constructor. | |
IF you select fused=True|None, this layer attributes as same as official BatchNormalization. | |
""" | |

from typing import Callable, List, Tuple, Union

import tensorflow as tf
from tensorflow import distribute, dtypes
from tensorflow.keras import constraints, initializers, layers, regularizers
from tensorflow.keras.layers import InputSpec
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils.tf_utils import constant_value, smart_cond
from tensorflow.python.ops.nn import fused_batch_norm
from tensorflow.python.platform import tf_logging as logging

# distribution_strategy_context, K, smart_cond, constant_value,
# base_layer_utils, logging, fused_batch_norm, etc. are private
# (hidden) TensorFlow APIs, so this layer may become unusable in
# future TensorFlow versions.

Initializer = initializers.Initializer
Regularizer = regularizers.Regularizer
Constraint = constraints.Constraint
Layer = layers.Layer
Strategy = distribute.Strategy

def _compose_transforms(
        scale: tf.Tensor, offset: tf.Tensor,
        then_scale, then_offset):
    if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
    if then_offset is not None:
        offset += then_offset
    return (scale, offset)
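
# For example, _compose_transforms(r, d, gamma, beta) returns
# (r * gamma, d * gamma + beta), i.e. the composed affine transform
# x * (r * gamma) + (d * gamma + beta) that batch renorm applies below.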

class SyncBatchNormalization(Layer):
    """
    args:
        - axis: Union[int, list]
        - momentum: float
        - epsilon: float
        - center: bool
        - scale: bool
        - beta_initializer: Union[str, Initializer]
        - gamma_initializer: Union[str, Initializer]
        - moving_mean_initializer: Union[str, Initializer]
        - moving_variance_initializer: Union[str, Initializer]
        - beta_regularizer: Union[str, Regularizer]
        - gamma_regularizer: Union[str, Regularizer]
        - beta_constraint: Union[str, Constraint]
        - gamma_constraint: Union[str, Constraint]
        - renorm: bool
        - renorm_clipping
        - renorm_momentum: float
        - fused: Union[bool, None]
            should be False (if True or None, sync batch norm is disabled)
        - trainable: bool
        - virtual_batch_size: int
        - adjustment
        - name: str
        - **kwargs
    """

    _USE_V2_BEHAVIOR = True

    def __init__(self,
                 axis: Union[int, list] = -1,
                 momentum: float = 0.99,
                 epsilon: float = 1e-3,
                 center: bool = True,
                 scale: bool = True,
                 beta_initializer: Union[str, Initializer] = 'zeros',
                 gamma_initializer: Union[str, Initializer] = 'ones',
                 moving_mean_initializer: Union[str, Initializer] = 'zeros',
                 moving_variance_initializer: Union[str, Initializer] = 'ones',
                 beta_regularizer: Union[str, Regularizer] = None,
                 gamma_regularizer: Union[str, Regularizer] = None,
                 beta_constraint: Union[str, Constraint] = None,
                 gamma_constraint: Union[str, Constraint] = None,
                 renorm: bool = False,
                 renorm_clipping=None,
                 renorm_momentum: float = 0.99,
                 fused: Union[bool, None] = None,
                 trainable: bool = True,
                 virtual_batch_size: int = None,
                 adjustment=None,
                 name: str = None,
                 **kwargs):
        super(SyncBatchNormalization, self).__init__(
            name=name,
            **kwargs)
        if isinstance(axis, (list, tuple)):
            self.axis = axis[:]
        elif isinstance(axis, int):
            self.axis = axis
        else:
            raise TypeError('Expected an int or list/tuple of ints for the '
                            'argument \'axis\', but received: {}'.format(axis))
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(
            moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(
            moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        self.renorm = renorm
        self.virtual_batch_size = virtual_batch_size
        self.adjustment = adjustment
        if self._USE_V2_BEHAVIOR:
            if fused:
                self._raise_if_fused_cannot_be_used()
            # We leave fused as None if self._fused_can_be_used()==True,
            # since we still may set it to False in self.build()
            # if the input rank is not 4.
            elif fused is None and not self._fused_can_be_used():
                fused = False
        elif fused is None:
            fused = True
        self.supports_masking = True
        self.fused = fused
        self._bessels_correction_test_only = True
        self._trainable_var = None
        self.trainable = trainable
        if renorm:
            renorm_clipping = renorm_clipping or {}
            keys = ['rmax', 'rmin', 'dmax']
            if set(renorm_clipping) - set(keys):
                raise ValueError('renorm_clipping %s contains keys not in %s' %
                                 (renorm_clipping, keys))
            self.renorm_clipping = renorm_clipping
            self.renorm_momentum = renorm_momentum

    def build(self, input_shape: Union[list, tf.TensorShape]):
        """
        args:
            - input_shape: list
                example. [None, H, W, C] = [None, 32, 32, 3] (CIFAR-10)
        """
        input_shape = tf.TensorShape(input_shape)
        if not input_shape.ndims:
            raise ValueError('Input has undefined rank:', input_shape)
        ndims = len(input_shape)
        # Convert axis to list and resolve negatives
        if isinstance(self.axis, int):
            self.axis = [self.axis]
        for idx, x in enumerate(self.axis):
            if x < 0:
                self.axis[idx] = ndims + x
        # Validate axes
        self._validate_axis(ndims)
        self._validate_virtual_batch_size()
        self._validate_fused(ndims)
        self._setup_fuse()
        param_shape = self._validate_dim(ndims, input_shape)
        if self.scale:
            self.gamma = self.add_weight(
                name='gamma',
                shape=param_shape,
                dtype=self._param_dtype,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                trainable=True,
                experimental_autocast=False)
        else:
            self.gamma = None
            if self.fused:
                self._gamma_const = K.constant(
                    1.0, dtype=self._param_dtype, shape=param_shape)
        if self.center:
            self.beta = self.add_weight(
                name='beta',
                shape=param_shape,
                dtype=self._param_dtype,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                trainable=True,
                experimental_autocast=False)
        else:
            self.beta = None
            if self.fused:
                self._beta_const = K.constant(
                    0.0, dtype=self._param_dtype, shape=param_shape)
        self._setup_movings(param_shape)
        self.built = True

    def call(self, inputs: tf.Tensor, training=None, **kwargs):
        training = self._get_training_value(training)
        if self.fused:
            return self.fused_normalize(inputs, training, **kwargs)
        else:
            return self.normalize(inputs, training, **kwargs)

    def normalize(self, inputs: tf.Tensor, training):
        inputs, undo_virtual_batching = self._gen_virtual_batch(inputs)
        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim
        # Broadcasting only necessary for single-axis batch norm
        # where the axis is not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
        _broadcast = self._gen_broadcast(
            ndims, reduction_axes, broadcast_shape)
        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
        training_value = constant_value(training)
        if training_value is False:
            mean, variance = self.untraining_stats()
        else:
            mean, variance = self.training_stats(
                inputs,
                scale,
                offset,
                reduction_axes,
                training,
                _broadcast
            )
        if offset is not None:
            offset = tf.cast(offset, inputs.dtype)
        if scale is not None:
            scale = tf.cast(scale, inputs.dtype)
        # TODO(reedwm): Maybe do math in float32
        # if given float16 inputs, if doing
        # math in float16 hurts validation accuracy of
        # popular models like resnet.
        outputs = tf.nn.batch_normalization(
            inputs, _broadcast(mean), _broadcast(variance),
            offset, scale, self.epsilon)
        # If some components of the shape got lost
        # due to adjustments, fix that.
        outputs.set_shape(input_shape)
        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs

    def fused_normalize(self, inputs: tf.Tensor, training, **kwargs):
        inputs, undo_virtual_batching = self._gen_virtual_batch(inputs)
        outputs = self._fused_batch_norm(inputs, training=training)
        if self.virtual_batch_size is not None:
            # Currently never reaches here
            # since fused_batch_norm does not support virtual batching
            outputs = undo_virtual_batching(outputs)
        return outputs

    @property
    def trainable(self) -> bool:
        return self._trainable

    @trainable.setter
    def trainable(self, value: bool):
        self._trainable = value
        if self._trainable_var is not None:
            self._trainable_var.update_value(value)

    def _get_trainable_var(self) -> tf.Tensor:
        if self._trainable_var is None:
            self._trainable_var = K.freezable_variable(
                self._trainable, name=self.name + '_trainable')
        return self._trainable_var

    def _raise_if_fused_cannot_be_used(self):
        """Raises a ValueError if the fused implementation cannot be used.

        In addition to the checks done in this function,
        the input tensor's rank must be 4.
        The input rank check can only be done once the input shape is known.
        """
        # Currently fused batch norm doesn't support renorm.
        # It also only supports a channel dimension on axis 1 or 3,
        # when no virtual batch size or adjustment is used.
        if self.renorm:
            raise ValueError('Passing both fused=True and renorm=True is '
                             'unsupported')
        axis = [self.axis] if isinstance(self.axis, int) else self.axis
        # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3,
        # because the input rank is required to be 4 (which is checked later).
        if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
            raise ValueError('Passing fused=True is only supported '
                             'when axis is 1 or 3')
        if self.virtual_batch_size is not None:
            raise ValueError('Passing fused=True is unsupported when '
                             'virtual_batch_size is specified.')
        if self.adjustment is not None:
            raise ValueError('Passing fused=True is unsupported when '
                             'adjustment is specified.')

    def _fused_can_be_used(self):
        try:
            self._raise_if_fused_cannot_be_used()
            return True
        except ValueError:
            return False

    @property
    def _param_dtype(self) -> dtypes.DType:
        # Raise parameters of fp16 batch norm to fp32
        if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
            return dtypes.float32
        else:
            return self.dtype or dtypes.float32

    def _support_zero_size_input(self) -> bool:
        # True when running under a distribution strategy that may feed
        # zero-size (empty) per-replica batches.
        return distribution_strategy_context.has_strategy() and getattr(
            distribution_strategy_context.get_strategy().extended,
            'experimental_enable_get_next_as_optional', False)

    def _validate_axis(self, ndims: int):
        for x in self.axis:
            if x < 0 or x >= ndims:
                raise ValueError('Invalid axis: {}'.format(x))
        if len(self.axis) != len(set(self.axis)):
            raise ValueError('Duplicate axis: {}'.format(self.axis))

    def _validate_virtual_batch_size(self):
        if self.virtual_batch_size is not None:
            if self.virtual_batch_size <= 0:
                raise ValueError(
                    'virtual_batch_size must be a positive integer that '
                    'divides the true batch size of the input Tensor')
            # If using virtual batches, the first dimension must be the batch
            # dimension and cannot be the batch norm axis
            if 0 in self.axis:
                raise ValueError(
                    'When using virtual_batch_size, the batch dimension '
                    'must be 0 and thus axis cannot include 0')
            if self.adjustment is not None:
                raise ValueError(
                    'When using virtual_batch_size, adjustment cannot '
                    'be specified')

    def _validate_fused(self, ndims: int):
        if self.fused in (None, True):
            # TODO(yaozhang): if input is not 4D, reshape it to 4D and
            # reshape the output back to its original shape accordingly.
            if self._USE_V2_BEHAVIOR:
                if self.fused is None:
                    self.fused = (ndims == 4)
                elif self.fused and ndims != 4:
                    raise ValueError(
                        'Batch normalization layers with fused=True only '
                        'support 4D input tensors.')
            else:
                # TODO(chrisying): fused batch norm is currently not supported
                # for multi-axis batch norm and by extension virtual batches.
                # In some cases, it might be possible to use fused batch norm
                # but would require reshaping the Tensor to 4D with the axis
                # in 1 or 3 (preferred 1) which is particularly tricky.
                # A compromise might be to just support the most common
                # use case (turning 5D w/ virtual batch to NCHW)
                # self.feature_dim = input_shape[self.axis]
                assert self.fused is not None
                self.fused = (ndims == 4 and self._fused_can_be_used())

    def _setup_fuse(self):
        if self.fused:
            if self.axis == [1]:
                self._data_format = 'NCHW'
            elif self.axis == [3]:
                self._data_format = 'NHWC'
            else:
                raise ValueError(
                    'Unsupported axis, fused batch norm only supports '
                    'axis == [1] or axis == [3]')

    def _validate_dim(self, ndims: int, input_shape: tf.TensorShape) -> List:
        axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
        # Note:
        #   {x: [None, 32, 32, 1][x] for x in [3]}
        #   => {3: 1}
        for x in axis_to_dim:
            if axis_to_dim[x] is None:
                raise ValueError(
                    'Input has undefined `axis` dimension. Input shape: ',
                    input_shape)
        self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)
        if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
            # Single axis batch norm (most common/default use-case)
            param_shape = (list(axis_to_dim.values())[0],)
        else:
            # Parameter shape is the original shape
            # but with 1 in all non-axis dims
            param_shape = [axis_to_dim[i] if i in axis_to_dim
                           else 1 for i in range(ndims)]
            if self.virtual_batch_size is not None:
                # When using virtual batches, add an extra dim at index 1
                param_shape.insert(1, 1)
                for idx, x in enumerate(self.axis):
                    # Account for added dimension
                    self.axis[idx] = x + 1
        return param_shape

    def _setup_movings(self, param_shape):
        try:
            # Disable variable partitioning
            # when creating the moving mean and variance
            if hasattr(self, '_scope') and self._scope:
                partitioner = self._scope.partitioner
                self._scope.set_partitioner(None)
            else:
                partitioner = None
            self.moving_mean = self.add_weight(
                name='moving_mean',
                shape=param_shape,
                dtype=self._param_dtype,
                initializer=self.moving_mean_initializer,
                synchronization=tf.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN,
                experimental_autocast=False)
            self.moving_variance = self.add_weight(
                name='moving_variance',
                shape=param_shape,
                dtype=self._param_dtype,
                initializer=self.moving_variance_initializer,
                synchronization=tf.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN,
                experimental_autocast=False)
            if self.renorm:
                # In batch renormalization we track the inference moving
                # stddev instead of the moving variance
                # to more closely align with the paper.
                def moving_stddev_initializer(*args, **kwargs):
                    return tf.math.sqrt(
                        self.moving_variance_initializer(*args, **kwargs))

                with distribution_strategy_context.get_strategy(
                ).extended.colocate_vars_with(self.moving_variance):
                    self.moving_stddev = self.add_weight(
                        name='moving_stddev',
                        shape=param_shape,
                        dtype=self._param_dtype,
                        initializer=moving_stddev_initializer,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        trainable=False,
                        aggregation=tf.VariableAggregation.MEAN,
                        experimental_autocast=False)
                # Create variables to maintain the moving mean and
                # standard deviation.
                # These are used in training and thus are different from
                # the moving averages above. The renorm variables are
                # colocated with moving_mean and moving_stddev.
                # NOTE: below, the outer `with device` block causes
                # the current device stack to be cleared.
                # The nested ones use a `lambda` to set the desired device
                # and ignore any devices that may be set by the custom getter.
                def _renorm_variable(name,
                                     shape,
                                     initializer=tf.zeros_initializer()):
                    """Create a renorm variable."""
                    var = self.add_weight(
                        name=name,
                        shape=shape,
                        dtype=self._param_dtype,
                        initializer=initializer,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        trainable=False,
                        aggregation=tf.VariableAggregation.MEAN,
                        experimental_autocast=False)
                    return var

                with distribution_strategy_context.get_strategy(
                ).extended.colocate_vars_with(self.moving_mean):
                    self.renorm_mean = _renorm_variable(
                        'renorm_mean', param_shape,
                        self.moving_mean_initializer)
                with distribution_strategy_context.get_strategy(
                ).extended.colocate_vars_with(self.moving_stddev):
                    self.renorm_stddev = _renorm_variable(
                        'renorm_stddev', param_shape,
                        moving_stddev_initializer)
        finally:
            if partitioner:
                self._scope.set_partitioner(partitioner)

    def _assign_moving_average(self,
                               variable: tf.Tensor,
                               value: tf.Tensor,
                               momentum: float,
                               inputs_size: int):
        with K.name_scope('AssignMovingAvg') as scope:
            decay = tf.convert_to_tensor(1.0 - momentum, name='decay')
            if decay.dtype != variable.dtype.base_dtype:
                decay = tf.cast(decay, variable.dtype.base_dtype)
            update_delta = (
                variable - tf.cast(value, variable.dtype)) * decay
            if inputs_size is not None:
                update_delta = tf.where(
                    inputs_size > 0, update_delta,
                    K.zeros_like(update_delta))
            return variable.assign_sub(update_delta, name=scope)

    def _assign_new_value(self, variable: tf.Tensor, value: tf.Tensor):
        with K.name_scope('AssignNewValue') as scope:
            return variable.assign(value, name=scope)

    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const
        # TODO(b/129279393): Support zero batch input in non
        # DistributionStrategy code as well.
        if self._support_zero_size_input():
            inputs_size = tf.size(inputs)
        else:
            inputs_size = None

        def _fused_batch_norm_training():
            return fused_batch_norm(
                inputs,
                gamma,
                beta,
                epsilon=self.epsilon,
                data_format=self._data_format)

        def _fused_batch_norm_inference():
            return fused_batch_norm(
                inputs,
                gamma,
                beta,
                mean=self.moving_mean,
                variance=self.moving_variance,
                epsilon=self.epsilon,
                is_training=False,
                data_format=self._data_format)

        output, mean, variance = smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent
            # with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = tf.math.cast(
                tf.size(inputs) / tf.size(variance), variance.dtype)
            factor = ((sample_size - tf.math.cast(1.0, variance.dtype))
                      / sample_size)
            variance *= factor
        training_value = constant_value(training)
        if training_value is None:
            momentum = smart_cond(training,
                                  lambda: self.momentum,
                                  lambda: 1.0)
        else:
            momentum = tf.convert_to_tensor(self.momentum)
        if training_value or training_value is None:
            def mean_update():
                return self._assign_moving_average(
                    self.moving_mean,
                    mean,
                    momentum,
                    inputs_size)

            def variance_update():
                """Update self.moving_variance with
                the most recent data point."""
                if self.renorm:
                    # We apply epsilon as part of the moving_stddev to mirror
                    # the training code path.
                    moving_stddev = self._assign_moving_average(
                        self.moving_stddev,
                        tf.math.sqrt(variance + self.epsilon),
                        momentum, inputs_size)
                    return self._assign_new_value(
                        self.moving_variance,
                        # Apply relu in case floating point rounding causes
                        # it to go negative.
                        K.relu(moving_stddev * moving_stddev - self.epsilon))
                else:
                    return self._assign_moving_average(
                        self.moving_variance,
                        variance,
                        momentum, inputs_size)

            self.add_update(mean_update)
            self.add_update(variance_update)
        return output

    def _renorm_correction_and_moments(
            self,
            mean: tf.Tensor,
            variance: tf.Tensor,
            training: bool,
            inputs_size: int):
        """Returns the correction and update values for renorm."""
        stddev = tf.math.sqrt(variance + self.epsilon)
        # Compute the average mean and standard deviation, as if they were
        # initialized with this batch's moments.
        renorm_mean = self.renorm_mean
        # Avoid divide by zero early on in training.
        renorm_stddev = tf.math.maximum(self.renorm_stddev,
                                        tf.math.sqrt(self.epsilon))
        # Compute the corrections for batch renorm.
        r = stddev / renorm_stddev
        d = (mean - renorm_mean) / renorm_stddev
        # Ensure the corrections use pre-update moving averages.
        with tf.control_dependencies([r, d]):
            mean = tf.identity(mean)
            stddev = tf.identity(stddev)
        rmin, rmax, dmax = [self.renorm_clipping.get(key)
                            for key in ['rmin', 'rmax', 'dmax']]
        if rmin is not None:
            r = tf.math.maximum(r, rmin)
        if rmax is not None:
            r = tf.math.minimum(r, rmax)
        if dmax is not None:
            d = tf.math.maximum(d, -dmax)
            d = tf.math.minimum(d, dmax)
        # When not training, use r=1, d=0.
        r = smart_cond(
            training,
            lambda: r,
            lambda: tf.ones_like(r))
        d = smart_cond(
            training,
            lambda: d,
            lambda: tf.zeros_like(d))

        def _update_renorm_variable(
                var: tf.Tensor, value: tf.Tensor, inputs_size: int):
            """Updates a moving average and weight,
            returns the unbiased value."""
            value = tf.identity(value)

            def _do_update():
                """Updates the var, returns the updated value."""
                new_var = self._assign_moving_average(
                    var, value, self.renorm_momentum, inputs_size)
                return new_var

            def _fake_update():
                return tf.identity(var)

            return smart_cond(training, _do_update, _fake_update)

        # TODO(yuefengz): colocate the operations
        update_new_mean = _update_renorm_variable(
            self.renorm_mean, mean, inputs_size)
        update_new_stddev = _update_renorm_variable(
            self.renorm_stddev, stddev, inputs_size)
        # Update the inference mode moving averages with the batch value.
        with tf.control_dependencies([update_new_mean, update_new_stddev]):
            out_mean = tf.identity(mean)
            out_variance = tf.identity(variance)
        return (r, d, out_mean, out_variance)

    def _moments(self,
                 inputs: tf.Tensor,
                 reduction_axes: list, keep_dims: bool):
        # not syncbatch
        # mean, variance = tf.nn.moments(
        #     inputs,
        #     reduction_axes,
        #     keep_dims=keep_dims)
        # @MokkeMeguru add syncbatch
        # Each replica contributes its local mean and mean of squares divided
        # by the number of replicas; summing them across replicas gives the
        # global E[x] and E[x^2] (assuming equal per-replica batch sizes),
        # and the global variance follows as E[x^2] - E[x]^2.
        ctx = tf.distribute.get_replica_context()
        n = ctx.num_replicas_in_sync
        mean, mean_sq = ctx.all_reduce(
            tf.distribute.ReduceOp.SUM,
            [tf.reduce_mean(
                inputs,
                axis=reduction_axes, keepdims=keep_dims) / n,
             tf.reduce_mean(
                 tf.square(inputs),
                 axis=reduction_axes, keepdims=keep_dims) / n]
        )
        variance = mean_sq - mean ** 2
        # TODO(b/129279393): Support zero batch input in non
        # DistributionStrategy code as well.
        if self._support_zero_size_input():
            inputs_size = tf.size(inputs)
            mean = tf.where(inputs_size > 0, mean, K.zeros_like(mean))
            variance = tf.where(inputs_size > 0, variance,
                                K.zeros_like(variance))
        return mean, variance

    def _get_training_value(self,
                            training: Union[bool, None] = None) -> bool:
        if training is None:
            training = K.learning_phase()
            # debug print
            tf.print('keras says:', training)
        if self._USE_V2_BEHAVIOR:
            if isinstance(training, int):
                training = bool(training)
            if base_layer_utils.is_in_keras_graph():
                training = tf.logical_and(
                    training, self._get_trainable_var())
            else:
                training = tf.logical_and(training, self.trainable)
        # debug print
        tf.print('at last: ', training)
        return training

    def _gen_broadcast(
            self,
            ndims: int,
            reduction_axes: list,
            broadcast_shape: list) -> Callable:
        """example:
        if NCHW,
            v: tf.Tensor [C] -> v': tf.Tensor [1, C, 1, 1]
        elif NHWC,
            v: tf.Tensor [C] -> v': tf.Tensor [C]
        """
        def _broadcast(v: tf.Tensor):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return tf.reshape(v, broadcast_shape)
            return v
        return _broadcast

    def _gen_virtual_batch(
            self, inputs: tf.Tensor) -> Tuple[tf.Tensor, Callable]:
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by
            # reshaping the Tensor and reusing
            # the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
            # Will cause errors if virtual_batch_size
            # does not divide the batch size
            inputs = tf.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = tf.reshape(outputs, original_shape)
                return outputs

            return inputs, undo_virtual_batching
        else:
            return inputs, None

    def _take_adjustment(
            self,
            inputs: tf.Tensor,
            scale: tf.Tensor,
            offset: tf.Tensor,
            training: bool):
        if self.adjustment:
            adj_scale, adj_bias = self.adjustment(tf.shape(inputs))
            # Adjust only during training.
            adj_scale = smart_cond(
                training,
                lambda: adj_scale,
                lambda: tf.ones_like(adj_scale))
            adj_bias = smart_cond(
                training,
                lambda: adj_bias,
                lambda: tf.zeros_like(adj_bias))
            scale, offset = _compose_transforms(
                adj_scale, adj_bias, scale, offset)
        return scale, offset

    def _take_virtual_batch(self, mean, variance):
        if self.virtual_batch_size is not None:
            # This isn't strictly correct since in ghost batch norm,
            # you are supposed to sequentially update the moving_mean
            # and moving_variance with each sub-batch. However,
            # since the moving statistics are only
            # used during evaluation, it is more efficient to just update
            # in one step and should not make a significant difference
            # in the result.
            new_mean = tf.reduce_mean(
                mean, axis=1, keepdims=True)
            new_variance = tf.reduce_mean(
                variance, axis=1, keepdims=True)
        else:
            new_mean, new_variance = mean, variance
        return new_mean, new_variance

    def _take_renorm(
            self,
            new_mean: tf.Tensor, new_variance: tf.Tensor,
            scale: tf.Tensor, offset: tf.Tensor,
            training: bool, inputs_size: int,
            _broadcast: Callable):
        tmp = self._renorm_correction_and_moments(
            new_mean,
            new_variance,
            training,
            inputs_size)
        # When training, the normalized values (say, x)
        # will be transformed as
        # - x * gamma + beta
        # without renorm, and
        # - (x * r + d) * gamma + beta
        #   = x * (r * gamma) + (d * gamma + beta)
        # with renorm.
        r, d, new_mean, new_variance = tmp
        r = _broadcast(tf.stop_gradient(r, name='renorm_r'))
        d = _broadcast(tf.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)
        return scale, offset

    def untraining_stats(self) -> Tuple[tf.Tensor, tf.Tensor]:
        return self.moving_mean, self.moving_variance

    def training_stats(
            self,
            inputs: tf.Tensor,
            scale: tf.Tensor,
            offset: tf.Tensor,
            reduction_axes: list,
            training: bool,
            _broadcast: Callable) -> Tuple[tf.Tensor, tf.Tensor]:
        # deal with adjustment
        scale, offset = self._take_adjustment(
            inputs, scale, offset, training)
        # Some of the computations here are unnecessary when training==False,
        # but training is not a constant; keeping them makes the code simpler.
        keep_dims = (self.virtual_batch_size is not None
                     or len(self.axis) > 1)
        mean, variance = self._moments(
            tf.cast(inputs, self._param_dtype),
            reduction_axes,
            keep_dims=keep_dims)
        moving_mean = self.moving_mean
        moving_variance = self.moving_variance
        mean = smart_cond(
            training,
            lambda: mean,
            lambda: tf.convert_to_tensor(moving_mean))
        variance = smart_cond(
            training,
            lambda: variance,
            lambda: tf.convert_to_tensor(moving_variance))
        # take virtual batch
        new_mean, new_variance = self._take_virtual_batch(
            mean, variance)
        if self._support_zero_size_input():
            inputs_size = tf.size(inputs)
        else:
            inputs_size = None
        # take renorm
        if self.renorm:
            scale, offset = self._take_renorm(
                new_mean, new_variance,
                scale, offset,
                training, inputs_size,
                _broadcast)

        def _do_update(var: tf.Tensor, value: tf.Tensor):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(
                var, value, self.momentum, inputs_size)

        def mean_update():
            def true_branch():
                return _do_update(self.moving_mean, new_mean)

            def false_branch():
                return self.moving_mean

            return smart_cond(training, true_branch, false_branch)

        def variance_update():
            """Update the moving variance."""
            def true_branch_renorm():
                # We apply epsilon as part of
                # the moving_stddev to mirror the training code path.
                moving_stddev = _do_update(
                    self.moving_stddev,
                    tf.math.sqrt(new_variance + self.epsilon))
                return self._assign_new_value(
                    self.moving_variance,
                    # Apply relu in case floating point rounding causes it
                    # to go negative.
                    K.relu(moving_stddev * moving_stddev - self.epsilon))

            if self.renorm:
                true_branch = true_branch_renorm
            else:
                def true_branch():
                    return _do_update(
                        self.moving_variance, new_variance)

            def false_branch():
                return self.moving_variance

            return smart_cond(training, true_branch, false_branch)

        self.add_update(mean_update)
        self.add_update(variance_update)
        mean = tf.cast(mean, inputs.dtype)
        variance = tf.cast(variance, inputs.dtype)
        return mean, variance

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer':
                initializers.serialize(self.beta_initializer),
            'gamma_initializer':
                initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer':
                initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer':
                initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer':
                regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer':
                regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        # Only add TensorFlow-specific parameters if they are set,
        # so as to preserve model compatibility with external Keras.
        if self.renorm:
            config['renorm'] = True
            config['renorm_clipping'] = self.renorm_clipping
            config['renorm_momentum'] = self.renorm_momentum
        if self.virtual_batch_size is not None:
            config['virtual_batch_size'] = self.virtual_batch_size
        # Note: adjustment is not serializable.
        if self.adjustment is not None:
            logging.warning(
                'The `adjustment` function of this `BatchNormalization` '
                'layer cannot be serialized and has been omitted from '
                'the layer config. It will not be included when '
                're-creating the layer from the saved config.')
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

def main():
    x = tf.keras.Input([32, 32, 3])
    y = SyncBatchNormalization()(x)
    model = tf.keras.Model(x, y)
    model.summary()
    x = tf.random.normal([32, 32, 32, 3])
    y = model(x)
    print(y.shape)


if __name__ == '__main__':
    main()
    tf.keras.layers.BatchNormalization().get_config()
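

# A minimal distributed-usage sketch (not part of the original gist): it
# assumes TensorFlow >= 2.2 (for Strategy.run) and at least one visible
# device, and only illustrates where this layer is meant to sit, i.e. built
# inside a tf.distribute.MirroredStrategy scope with fused=False so that the
# cross-replica `_moments` path above is actually exercised.
def distributed_example():
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        inputs = tf.keras.Input([32, 32, 3])
        outputs = SyncBatchNormalization(fused=False)(inputs)
        model = tf.keras.Model(inputs, outputs)

    @tf.function
    def step(batch):
        def replica_fn(x):
            # Per-replica forward pass; batch statistics are all-reduced
            # across replicas inside SyncBatchNormalization._moments.
            return model(x, training=True)
        return strategy.run(replica_fn, args=(batch,))

    return step(tf.random.normal([8, 32, 32, 3]))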
Interesting implementation, @MokkeMeguru!
Before spending too much time on testing it, do you know whether this implementation is compatible with gradient accumulation (GA)? If you are not aware of the concept, I have made a tool for it here.
A problem with doing GA (currently) is that batch norm (BN) is not compatible with it. This is a major hurdle, as a lot of people use pretrained models which often depend on BN. Hence, it would be great to find a proper solution for it.
I have seen this implementation of SyncBatchNormalization in tf-contrib. However, I noticed that you had an interesting argument, `virtual_batch_size`, which I hope could be what I was looking for. Any ideas?
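For reference, `virtual_batch_size` works like ghost batch normalization: the incoming batch is reshaped into sub-batches that are normalized independently. A minimal sketch with the stock Keras layer (illustrative only, and not a statement about GA compatibility of this gist):

    import tensorflow as tf

    layer = tf.keras.layers.BatchNormalization(virtual_batch_size=8)
    x = tf.random.normal([32, 8, 8, 3])  # true batch of 32 -> 4 ghost batches of 8
    y = layer(x, training=True)          # each ghost batch uses its own statistics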