Skip to content

Instantly share code, notes, and snippets.

@MokkeMeguru
Last active November 6, 2022 19:12
Show Gist options
  • Save MokkeMeguru/35af0c7ddba511f6a268e7c78fdba2d6 to your computer and use it in GitHub Desktop.
Save MokkeMeguru/35af0c7ddba511f6a268e7c78fdba2d6 to your computer and use it in GitHub Desktop.
""" My synchronized BatchNormalization layer (cross-replica batch statistics).
References & thanks:
https://github.com/tensorflow/tensorflow/issues/18222
https://github.com/jkyl/biggan-deep/blob/master/src/custom_layers/batch_normalization.py
https://github.com/tensorflow/community/blob/master/rfcs/20181016-replicator.md#global-batch-normalization
https://github.com/Apm5/tensorflow_2.0_tutorial/blob/master/CNN/BatchNormalization.py
WARNING: Statistics are synchronized across replicas only on the non-fused
code path, so you must pass fused=False to the constructor.
If fused=True, or fused=None (the default) with a 4D input, this layer uses
the fused kernel and behaves the same as the official BatchNormalization.
"""
from typing import Callable, List, Tuple, Union
import tensorflow as tf
from tensorflow import distribute, dtypes
from tensorflow.keras import constraints, initializers, layers, regularizers
from tensorflow.keras.layers import InputSpec
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils.tf_utils import constant_value, smart_cond
from tensorflow.python.ops.nn import fused_batch_norm
from tensorflow.python.platform import tf_logging as logging
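# NOTE: distribution_strategy_context, K, smart_cond, constant_value,
# base_layer_utils, logging, fused_batch_norm, etc. are private TensorFlow
# internals (imported from tensorflow.python.*), not public API,
# so this layer may become unusable in other TensorFlow versions.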
# Short local aliases for the Keras / tf.distribute classes used by this module.
Initializer = initializers.Initializer
Regularizer = regularizers.Regularizer
Constraint = constraints.Constraint
Layer = layers.Layer
Strategy = distribute.Strategy
def _compose_transforms(
scale: tf.Tensor, offset: tf.Tensor,
then_scale, then_offset):
if then_scale is not None:
scale *= then_scale
offset *= then_scale
if then_offset is not None:
offset += then_offset
return (scale, offset)
class SyncBatchNormalization(Layer):
    """Batch normalization whose batch statistics are reduced across replicas.

    A copy of the TF 2.x Keras ``BatchNormalization`` layer. On the non-fused
    training path, the mean and variance are computed over the global batch of
    the current ``tf.distribute`` replica context (see ``_moments``).

    IMPORTANT: synchronization happens only on the non-fused path. With
    ``fused=True``, or ``fused=None`` (the default) and a 4D input, the fused
    kernel is used with per-replica statistics, exactly like the official
    layer. Pass ``fused=False`` to get synchronized batch normalization.

    Args:
        axis: Union[int, list]. Feature axis (or axes) that are normalized.
        momentum: float. Momentum of the moving averages
            (each update moves the average by ``1 - momentum``).
        epsilon: float. Added to the variance before taking its square root.
        center: bool. Whether to add the learned offset ``beta``.
        scale: bool. Whether to multiply by the learned scale ``gamma``.
        beta_initializer: Union[str, Initializer]
        gamma_initializer: Union[str, Initializer]
        moving_mean_initializer: Union[str, Initializer]
        moving_variance_initializer: Union[str, Initializer]
        beta_regularizer: Union[str, Regularizer]
        gamma_regularizer: Union[str, Regularizer]
        beta_constraint: Union[str, Constraint]
        gamma_constraint: Union[str, Constraint]
        renorm: bool. Use batch renormalization.
        renorm_clipping: dict with optional keys 'rmax', 'rmin', 'dmax'.
        renorm_momentum: float. Momentum of the renorm moving averages.
        fused: Union[bool, None]. Default None: decided in ``build`` (fused
            iff the input is 4D). Must be False for synchronized statistics.
        trainable: bool
        virtual_batch_size: int. Size of "ghost" sub-batches, or None.
        adjustment: callable mapping ``tf.shape(inputs)`` to a
            ``(scale, bias)`` pair that is applied during training only.
        name: str
        **kwargs: forwarded to ``Layer``.
    """
    # Selects the TF2 semantics for `fused` and `training` handling.
    _USE_V2_BEHAVIOR = True

    def __init__(self,
                 axis: Union[int, list] = -1,
                 momentum: float = 0.99,
                 epsilon: float = 1e-3,
                 center: bool = True,
                 scale: bool = True,
                 beta_initializer: Union[str, Initializer] = 'zeros',
                 gamma_initializer: Union[str, Initializer] = 'ones',
                 moving_mean_initializer: Union[str, Initializer] = 'zeros',
                 moving_variance_initializer: Union[str, Initializer] = 'ones',
                 beta_regularizer: Union[str, Regularizer] = None,
                 gamma_regularizer: Union[str, Regularizer] = None,
                 beta_constraint: Union[str, Constraint] = None,
                 gamma_constraint: Union[str, Constraint] = None,
                 renorm: bool = False,
                 renorm_clipping=None,
                 renorm_momentum: float = 0.99,
                 fused: Union[bool, None] = None,
                 trainable: bool = True,
                 virtual_batch_size: int = None,
                 adjustment=None,
                 name: str = None,
                 **kwargs):
        super(SyncBatchNormalization, self).__init__(
            name=name,
            **kwargs)
        if isinstance(axis, (list, tuple)):
            # Store a mutable copy: build() rewrites negative axes in place,
            # which would raise TypeError on a tuple.
            self.axis = list(axis)
        elif isinstance(axis, int):
            self.axis = axis
        else:
            raise TypeError('Expected an int or list/tuple of ints for the '
                            'argument \'axis\', but received: {}'.format(axis))
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(
            moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(
            moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        self.renorm = renorm
        self.virtual_batch_size = virtual_batch_size
        self.adjustment = adjustment
        if self._USE_V2_BEHAVIOR:
            if fused:
                self._raise_if_fused_cannot_be_used()
            # We leave fused as None if self._fused_can_be_used() is True,
            # since build() may still set it to False if the input rank
            # is not 4.
            elif fused is None and not self._fused_can_be_used():
                fused = False
        elif fused is None:
            fused = True
        self.supports_masking = True
        self.fused = fused
        self._bessels_correction_test_only = True
        self._trainable_var = None
        self.trainable = trainable
        if renorm:
            renorm_clipping = renorm_clipping or {}
            keys = ['rmax', 'rmin', 'dmax']
            if set(renorm_clipping) - set(keys):
                raise ValueError('renorm_clipping %s contains keys not in %s' %
                                 (renorm_clipping, keys))
            self.renorm_clipping = renorm_clipping
            self.renorm_momentum = renorm_momentum

    def build(self, input_shape: Union[list, tf.TensorShape]):
        """Create the layer's weights for inputs of ``input_shape``.

        Args:
            input_shape: e.g. [None, H, W, C] = [None, 32, 32, 3] (CIFAR-10).

        Raises:
            ValueError: if the rank or a normalized dimension is unknown, or
                the axis / fused / virtual-batch settings are inconsistent.
        """
        input_shape = tf.TensorShape(input_shape)
        if not input_shape.ndims:
            raise ValueError('Input has undefined rank:', input_shape)
        ndims = len(input_shape)
        # Convert axis to a list and resolve negative indices.
        if isinstance(self.axis, int):
            self.axis = [self.axis]
        for idx, x in enumerate(self.axis):
            if x < 0:
                self.axis[idx] = ndims + x
        # Validate axes and the options that depend on the input rank.
        self._validate_axis(ndims)
        self._validate_virtual_batch_size()
        self._validate_fused(ndims)
        self._setup_fuse()
        param_shape = self._validate_dim(ndims, input_shape)
        if self.scale:
            self.gamma = self.add_weight(
                name='gamma',
                shape=param_shape,
                dtype=self._param_dtype,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                trainable=True,
                experimental_autocast=False)
        else:
            self.gamma = None
            if self.fused:
                # _fused_batch_norm always passes a gamma tensor.
                self._gamma_const = K.constant(
                    1.0, dtype=self._param_dtype, shape=param_shape)
        if self.center:
            self.beta = self.add_weight(
                name='beta',
                shape=param_shape,
                dtype=self._param_dtype,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                trainable=True,
                experimental_autocast=False)
        else:
            self.beta = None
            if self.fused:
                # _fused_batch_norm always passes a beta tensor.
                self._beta_const = K.constant(
                    0.0, dtype=self._param_dtype, shape=param_shape)
        self._setup_movings(param_shape)
        self.built = True

    def call(self, inputs: tf.Tensor, training=None, **kwargs):
        """Normalize ``inputs`` with the fused or the synchronized path.

        Extra keyword arguments (e.g. a ``mask``; the layer sets
        ``supports_masking``) are accepted and ignored by both paths.
        """
        training = self._get_training_value(training)
        if self.fused:
            return self.fused_normalize(inputs, training, **kwargs)
        return self.normalize(inputs, training, **kwargs)

    def normalize(self, inputs: tf.Tensor, training, **kwargs):
        """Non-fused batch norm; training statistics are cross-replica.

        Extra keyword arguments are accepted and ignored so that ``call``
        can forward whatever Keras passes it.
        """
        inputs, undo_virtual_batching = self._gen_virtual_batch(inputs)
        # Compute the axes along which to reduce the mean / variance.
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim.
        # Broadcasting is only necessary for single-axis batch norm
        # where the axis is not the last dimension.
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
        _broadcast = self._gen_broadcast(
            ndims, reduction_axes, broadcast_shape)
        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        training_value = constant_value(training)
        # constant_value can return a NumPy bool (e.g. for an eager tensor),
        # so test truthiness rather than identity with `False`.
        if training_value is not None and not training_value:
            mean, variance = self.untraining_stats()
            # Match the dtype that the training path returns.
            mean = tf.cast(mean, inputs.dtype)
            variance = tf.cast(variance, inputs.dtype)
        else:
            # scale/offset come back with the adjustment and batch-renorm
            # corrections folded in; they must be used for the output.
            mean, variance, scale, offset = (
                self._training_stats_and_transforms(
                    inputs, scale, offset, reduction_axes, training,
                    _broadcast))
        if offset is not None:
            offset = tf.cast(offset, inputs.dtype)
        if scale is not None:
            scale = tf.cast(scale, inputs.dtype)
        # TODO(reedwm): Maybe do math in float32 if given float16 inputs,
        # if doing math in float16 hurts validation accuracy of popular
        # models like resnet.
        outputs = tf.nn.batch_normalization(
            inputs, _broadcast(mean), _broadcast(variance),
            offset, scale, self.epsilon)
        # If some components of the shape got lost due to adjustments,
        # fix that.
        outputs.set_shape(input_shape)
        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs

    def fused_normalize(self, inputs: tf.Tensor, training, **kwargs):
        """Fused batch norm (per-replica statistics, NOT synchronized)."""
        inputs, undo_virtual_batching = self._gen_virtual_batch(inputs)
        outputs = self._fused_batch_norm(inputs, training=training)
        if self.virtual_batch_size is not None:
            # Currently never reached, since fused batch norm does not
            # support virtual batching.
            outputs = undo_virtual_batching(outputs)
        return outputs

    @property
    def trainable(self) -> bool:
        return self._trainable

    @trainable.setter
    def trainable(self, value: bool):
        self._trainable = value
        if self._trainable_var is not None:
            self._trainable_var.update_value(value)

    def _get_trainable_var(self) -> tf.Tensor:
        """Lazily create the variable mirroring ``trainable`` in Keras graphs."""
        if self._trainable_var is None:
            self._trainable_var = K.freezable_variable(
                self._trainable, name=self.name + '_trainable')
        return self._trainable_var

    def _raise_if_fused_cannot_be_used(self):
        """Raises a ValueError if the fused implementation cannot be used.

        In addition to these checks, the input tensor's rank must be 4; that
        check can only be done once the input shape is known (in build).
        """
        # Currently fused batch norm doesn't support renorm. It also only
        # supports a channel dimension on axis 1 or 3, when no virtual batch
        # size or adjustment is used.
        if self.renorm:
            raise ValueError('Passing both fused=True and renorm=True is '
                             'unsupported')
        axis = [self.axis] if isinstance(self.axis, int) else self.axis
        # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because
        # the input rank is required to be 4 (which is checked later).
        if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
            raise ValueError('Passing fused=True is only supported '
                             'when axis is 1 or 3')
        if self.virtual_batch_size is not None:
            raise ValueError('Passing fused=True is unsupported when '
                             'virtual_batch_size is specified.')
        if self.adjustment is not None:
            raise ValueError('Passing fused=True is unsupported when '
                             'adjustment is specified.')

    def _fused_can_be_used(self) -> bool:
        try:
            self._raise_if_fused_cannot_be_used()
            return True
        except ValueError:
            return False

    @property
    def _param_dtype(self) -> dtypes.DType:
        # Raise parameters of fp16 batch norm to fp32.
        if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
            return dtypes.float32
        return self.dtype or dtypes.float32

    def _support_zero_size_input(self) -> bool:
        """True if the active strategy may feed empty per-replica batches."""
        return distribution_strategy_context.has_strategy() and getattr(
            distribution_strategy_context.get_strategy().extended,
            'experimental_enable_get_next_as_optional', False)

    def _validate_axis(self, ndims: int):
        for x in self.axis:
            if x < 0 or x >= ndims:
                raise ValueError('Invalid axis: {}'.format(x))
        if len(self.axis) != len(set(self.axis)):
            raise ValueError('Duplicate axis: {}'.format(self.axis))

    def _validate_virtual_batch_size(self):
        if self.virtual_batch_size is not None:
            if self.virtual_batch_size <= 0:
                raise ValueError(
                    'virtual_batch_size must be a positive integer that '
                    'divides the true batch size of the input Tensor')
            # If using virtual batches, the first dimension must be the batch
            # dimension and cannot be the batch norm axis.
            if 0 in self.axis:
                raise ValueError(
                    'When using virtual_batch_size, the batch dimension '
                    'must be 0 and thus axis cannot include 0')
            if self.adjustment is not None:
                raise ValueError(
                    'When using virtual_batch_size, adjustment cannot '
                    'be specified')

    def _validate_fused(self, ndims: int):
        """Resolve ``self.fused`` now that the input rank is known."""
        if self.fused in (None, True):
            # TODO(yaozhang): if input is not 4D, reshape it to 4D and
            # reshape the output back to its original shape accordingly.
            if self._USE_V2_BEHAVIOR:
                if self.fused is None:
                    self.fused = (ndims == 4)
                elif self.fused and ndims != 4:
                    raise ValueError(
                        'Batch normalization layers with fused=True only '
                        'support 4D input tensors.')
            else:
                # TODO(chrisying): fused batch norm is currently not supported
                # for multi-axis batch norm and by extension virtual batches.
                # In some cases, it might be possible to use fused batch norm
                # but would require reshaping the Tensor to 4D with the axis
                # in 1 or 3 (preferred 1) which is particularly tricky.
                assert self.fused is not None
                self.fused = (ndims == 4 and self._fused_can_be_used())

    def _setup_fuse(self):
        """Pick the data format for the fused kernel from the axis."""
        if self.fused:
            if self.axis == [1]:
                self._data_format = 'NCHW'
            elif self.axis == [3]:
                self._data_format = 'NHWC'
            else:
                raise ValueError(
                    'Unsupported axis, fused batch norm only supports '
                    'axis == [1] or axis == [3]')

    def _validate_dim(self, ndims: int, input_shape: tf.TensorShape) -> List:
        """Set ``input_spec`` and return the shape of the per-feature params.

        Example: {x: [None, 32, 32, 1][x] for x in [3]} => {3: 1}
        """
        axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
        for x in axis_to_dim:
            if axis_to_dim[x] is None:
                raise ValueError(
                    'Input has undefined `axis` dimension. Input shape: ',
                    input_shape)
        self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)
        if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
            # Single axis batch norm (most common/default use-case).
            param_shape = (list(axis_to_dim.values())[0],)
        else:
            # Parameter shape is the original shape but with 1 in all
            # non-axis dims.
            param_shape = [axis_to_dim[i] if i in axis_to_dim else 1
                           for i in range(ndims)]
            if self.virtual_batch_size is not None:
                # When using virtual batches, add an extra dim at index 1.
                param_shape.insert(1, 1)
                for idx, x in enumerate(self.axis):
                    # Account for the added dimension.
                    self.axis[idx] = x + 1
        return param_shape

    def _setup_movings(self, param_shape):
        """Create the moving statistics (and renorm variables if enabled)."""
        try:
            # Disable variable partitioning when creating the moving mean
            # and variance.
            if hasattr(self, '_scope') and self._scope:
                partitioner = self._scope.partitioner
                self._scope.set_partitioner(None)
            else:
                partitioner = None
            self.moving_mean = self.add_weight(
                name='moving_mean',
                shape=param_shape,
                dtype=self._param_dtype,
                initializer=self.moving_mean_initializer,
                synchronization=tf.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN,
                experimental_autocast=False)
            self.moving_variance = self.add_weight(
                name='moving_variance',
                shape=param_shape,
                dtype=self._param_dtype,
                initializer=self.moving_variance_initializer,
                synchronization=tf.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN,
                experimental_autocast=False)
            if self.renorm:
                # In batch renormalization we track the inference moving
                # stddev instead of the moving variance to more closely
                # align with the paper.
                def moving_stddev_initializer(*args, **kwargs):
                    return tf.math.sqrt(
                        self.moving_variance_initializer(*args, **kwargs))

                with distribution_strategy_context.get_strategy(
                ).extended.colocate_vars_with(self.moving_variance):
                    self.moving_stddev = self.add_weight(
                        name='moving_stddev',
                        shape=param_shape,
                        dtype=self._param_dtype,
                        initializer=moving_stddev_initializer,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        trainable=False,
                        aggregation=tf.VariableAggregation.MEAN,
                        experimental_autocast=False)

                # Create variables to maintain the moving mean and standard
                # deviation. These are used in training and thus are
                # different from the moving averages above. The renorm
                # variables are colocated with moving_mean and moving_stddev.
                def _renorm_variable(name,
                                     shape,
                                     initializer=tf.zeros_initializer()):
                    """Create a renorm variable."""
                    var = self.add_weight(
                        name=name,
                        shape=shape,
                        dtype=self._param_dtype,
                        initializer=initializer,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        trainable=False,
                        aggregation=tf.VariableAggregation.MEAN,
                        experimental_autocast=False)
                    return var

                with distribution_strategy_context.get_strategy(
                ).extended.colocate_vars_with(self.moving_mean):
                    self.renorm_mean = _renorm_variable(
                        'renorm_mean', param_shape,
                        self.moving_mean_initializer)
                with distribution_strategy_context.get_strategy(
                ).extended.colocate_vars_with(self.moving_stddev):
                    self.renorm_stddev = _renorm_variable(
                        'renorm_stddev', param_shape,
                        moving_stddev_initializer)
        finally:
            if partitioner:
                self._scope.set_partitioner(partitioner)

    def _assign_moving_average(self,
                               variable: tf.Tensor,
                               value: tf.Tensor,
                               momentum: float,
                               inputs_size: int):
        """variable -= (variable - value) * (1 - momentum).

        When ``inputs_size`` is given and is 0, the update delta is zeroed.
        """
        with K.name_scope('AssignMovingAvg') as scope:
            decay = tf.convert_to_tensor(1.0 - momentum, name='decay')
            if decay.dtype != variable.dtype.base_dtype:
                decay = tf.cast(decay, variable.dtype.base_dtype)
            update_delta = (
                variable - tf.cast(value, variable.dtype)) * decay
            if inputs_size is not None:
                update_delta = tf.where(
                    inputs_size > 0, update_delta,
                    K.zeros_like(update_delta))
            return variable.assign_sub(update_delta, name=scope)

    def _assign_new_value(self, variable: tf.Tensor, value: tf.Tensor):
        with K.name_scope('AssignNewValue') as scope:
            return variable.assign(value, name=scope)

    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm (not synchronized)."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const
        # TODO(b/129279393): Support zero batch input in non
        # DistributionStrategy code as well.
        if self._support_zero_size_input():
            inputs_size = tf.size(inputs)
        else:
            inputs_size = None

        def _fused_batch_norm_training():
            return fused_batch_norm(
                inputs,
                gamma,
                beta,
                epsilon=self.epsilon,
                data_format=self._data_format)

        def _fused_batch_norm_inference():
            return fused_batch_norm(
                inputs,
                gamma,
                beta,
                mean=self.moving_mean,
                variance=self.moving_variance,
                epsilon=self.epsilon,
                is_training=False,
                data_format=self._data_format)

        output, mean, variance = smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused
            # batch norm. Note that the variance computed by fused batch
            # norm is with Bessel's correction.
            sample_size = tf.math.cast(
                tf.size(inputs) / tf.size(variance), variance.dtype)
            factor = ((sample_size - tf.math.cast(1.0, variance.dtype))
                      / sample_size)
            variance *= factor

        training_value = constant_value(training)
        if training_value is None:
            momentum = smart_cond(training,
                                  lambda: self.momentum,
                                  lambda: 1.0)
        else:
            momentum = tf.convert_to_tensor(self.momentum)
        if training_value or training_value is None:
            def mean_update():
                return self._assign_moving_average(
                    self.moving_mean,
                    mean,
                    momentum,
                    inputs_size)

            def variance_update():
                """Update self.moving_variance with the most recent data."""
                if self.renorm:
                    # We apply epsilon as part of the moving_stddev to mirror
                    # the training code path.
                    moving_stddev = self._assign_moving_average(
                        self.moving_stddev,
                        tf.math.sqrt(variance + self.epsilon),
                        momentum, inputs_size)
                    return self._assign_new_value(
                        self.moving_variance,
                        # Apply relu in case floating point rounding causes
                        # it to go negative.
                        K.relu(moving_stddev * moving_stddev - self.epsilon))
                return self._assign_moving_average(
                    self.moving_variance,
                    variance,
                    momentum, inputs_size)

            self.add_update(mean_update)
            self.add_update(variance_update)
        return output

    def _renorm_correction_and_moments(
            self,
            mean: tf.Tensor,
            variance: tf.Tensor,
            training,
            inputs_size: int):
        """Returns the correction and update values for renorm."""
        stddev = tf.math.sqrt(variance + self.epsilon)
        # Compute the average mean and standard deviation, as if they were
        # initialized with this batch's moments.
        renorm_mean = self.renorm_mean
        # Avoid divide by zero early on in training.
        renorm_stddev = tf.math.maximum(self.renorm_stddev,
                                        tf.math.sqrt(self.epsilon))
        # Compute the corrections for batch renorm.
        r = stddev / renorm_stddev
        d = (mean - renorm_mean) / renorm_stddev
        # Ensure the corrections use pre-update moving averages.
        with tf.control_dependencies([r, d]):
            mean = tf.identity(mean)
            stddev = tf.identity(stddev)
        rmin, rmax, dmax = [self.renorm_clipping.get(key)
                            for key in ['rmin', 'rmax', 'dmax']]
        if rmin is not None:
            r = tf.math.maximum(r, rmin)
        if rmax is not None:
            r = tf.math.minimum(r, rmax)
        if dmax is not None:
            d = tf.math.maximum(d, -dmax)
            d = tf.math.minimum(d, dmax)
        # When not training, use r=1, d=0.
        r = smart_cond(training, lambda: r, lambda: tf.ones_like(r))
        d = smart_cond(training, lambda: d, lambda: tf.zeros_like(d))

        def _update_renorm_variable(
                var: tf.Tensor, value: tf.Tensor, inputs_size: int):
            """Updates a moving average and weight, returns the value."""
            value = tf.identity(value)

            def _do_update():
                """Updates the var, returns the updated value."""
                return self._assign_moving_average(
                    var, value, self.renorm_momentum, inputs_size)

            def _fake_update():
                return tf.identity(var)

            return smart_cond(training, _do_update, _fake_update)

        # TODO(yuefengz): colocate the operations
        update_new_mean = _update_renorm_variable(
            self.renorm_mean, mean, inputs_size)
        update_new_stddev = _update_renorm_variable(
            self.renorm_stddev, stddev, inputs_size)
        # Update the inference mode moving averages with the batch value.
        with tf.control_dependencies([update_new_mean, update_new_stddev]):
            out_mean = tf.identity(mean)
            out_variance = tf.identity(variance)
        return (r, d, out_mean, out_variance)

    def _moments(self,
                 inputs: tf.Tensor,
                 reduction_axes: list, keep_dims: bool):
        """Return (mean, variance) of ``inputs`` over the global batch.

        Inside a replica context, the per-replica sums of x and x**2 and the
        per-replica element count are all-reduced with SUM. Every replica
        then gets the statistics of the union of all replicas' batches.
        Summing counts (instead of averaging per-replica means) keeps the
        result correct when replicas hold different numbers of examples, and
        an empty replica contributes zeros instead of NaN.

        In a cross-replica context (no replica context available), the local
        moments are used.
        """
        ctx = tf.distribute.get_replica_context()
        if ctx is None:
            mean, variance = tf.nn.moments(
                inputs, reduction_axes, keepdims=keep_dims)
        else:
            local_sum = tf.reduce_sum(
                inputs, axis=reduction_axes, keepdims=keep_dims)
            local_squared_sum = tf.reduce_sum(
                tf.square(inputs), axis=reduction_axes, keepdims=keep_dims)
            # Number of elements reduced into each statistic on this replica.
            reduced_dims = tf.gather(
                tf.shape(inputs),
                tf.constant(reduction_axes, dtype=tf.int32))
            local_count = tf.reduce_prod(tf.cast(reduced_dims, inputs.dtype))
            global_sum, global_squared_sum, global_count = ctx.all_reduce(
                tf.distribute.ReduceOp.SUM,
                [local_sum, local_squared_sum, local_count])
            # Avoid 0 / 0 when every replica received an empty batch.
            global_count = tf.maximum(global_count,
                                      tf.ones_like(global_count))
            mean = global_sum / global_count
            # var = E[x^2] - E[x]^2; clamp the small negative values that
            # floating-point cancellation can produce.
            variance = tf.maximum(
                global_squared_sum / global_count - tf.square(mean),
                tf.zeros_like(mean))
        # TODO(b/129279393): Support zero batch input in non
        # DistributionStrategy code as well.
        if self._support_zero_size_input():
            inputs_size = tf.size(inputs)
            mean = tf.where(inputs_size > 0, mean, K.zeros_like(mean))
            variance = tf.where(inputs_size > 0, variance,
                                K.zeros_like(variance))
        return mean, variance

    def _get_training_value(
            self,
            training: Union[bool, None] = None) -> Union[bool, tf.Tensor]:
        """Resolve the effective training flag.

        A Python bool is kept static where possible, so that ``normalize``
        can take its inference fast path.
        """
        if training is None:
            training = K.learning_phase()
        if self._USE_V2_BEHAVIOR:
            if isinstance(training, int):
                training = bool(training)
            if base_layer_utils.is_in_keras_graph():
                training = tf.logical_and(
                    training, self._get_trainable_var())
            elif not self.trainable:
                # A frozen layer always runs in inference mode.
                training = self.trainable
        return training

    def _gen_broadcast(
            self,
            ndims: int,
            reduction_axes: list,
            broadcast_shape: list) -> Callable:
        """Return a function that reshapes per-feature params for broadcasting.

        example:
            if NCHW, v: tf.Tensor [C] -> v': tf.Tensor [1, C, 1, 1]
            elif NHWC, v: tf.Tensor [C] -> v': tf.Tensor [C]
        """
        def _broadcast(v: tf.Tensor):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return tf.reshape(v, broadcast_shape)
            return v
        return _broadcast

    def _gen_virtual_batch(
            self, inputs: tf.Tensor) -> Tuple[tf.Tensor, Callable]:
        """Reshape to [virtual_batch_size, -1, ...] and return an undo fn."""
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by
            # reshaping the Tensor and reusing the existing batch norm
            # implementation.
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
            # Will cause errors if virtual_batch_size does not divide the
            # batch size.
            inputs = tf.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = tf.reshape(outputs, original_shape)
                return outputs

            return inputs, undo_virtual_batching
        return inputs, None

    def _take_adjustment(
            self,
            inputs: tf.Tensor,
            scale: tf.Tensor,
            offset: tf.Tensor,
            training):
        """Fold the (training-only) ``adjustment`` into scale and offset."""
        if self.adjustment:
            adj_scale, adj_bias = self.adjustment(tf.shape(inputs))
            # Adjust only during training.
            adj_scale = smart_cond(
                training,
                lambda: adj_scale,
                lambda: tf.ones_like(adj_scale))
            adj_bias = smart_cond(
                training,
                lambda: adj_bias,
                lambda: tf.zeros_like(adj_bias))
            scale, offset = _compose_transforms(
                adj_scale, adj_bias, scale, offset)
        return scale, offset

    def _take_virtual_batch(self, mean, variance):
        """Average per-virtual-batch statistics for the moving updates."""
        if self.virtual_batch_size is not None:
            # This isn't strictly correct since in ghost batch norm, you are
            # supposed to sequentially update the moving_mean and
            # moving_variance with each sub-batch. However, since the moving
            # statistics are only used during evaluation, it is more
            # efficient to just update in one step and should not make a
            # significant difference in the result.
            new_mean = tf.reduce_mean(mean, axis=1, keepdims=True)
            new_variance = tf.reduce_mean(variance, axis=1, keepdims=True)
        else:
            new_mean, new_variance = mean, variance
        return new_mean, new_variance

    def _take_renorm(
            self,
            new_mean: tf.Tensor, new_variance: tf.Tensor,
            scale: tf.Tensor, offset: tf.Tensor,
            training, inputs_size: int,
            _broadcast: Callable):
        """Apply batch renorm.

        Returns ``(scale, offset, new_mean, new_variance)``. The returned
        moments carry control dependencies on the renorm variable updates.
        """
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean,
            new_variance,
            training,
            inputs_size)
        # When training, the normalized values (say, x) will be transformed as
        #   x * gamma + beta
        # without renorm, and
        #   (x * r + d) * gamma + beta = x * (r * gamma) + (d * gamma + beta)
        # with renorm.
        r = _broadcast(tf.stop_gradient(r, name='renorm_r'))
        d = _broadcast(tf.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)
        return scale, offset, new_mean, new_variance

    def untraining_stats(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """Statistics used in inference: the moving mean and variance."""
        return self.moving_mean, self.moving_variance

    def _training_stats_and_transforms(
            self,
            inputs: tf.Tensor,
            scale: tf.Tensor,
            offset: tf.Tensor,
            reduction_axes: list,
            training,
            _broadcast: Callable):
        """Compute batch statistics and register the moving-average updates.

        Returns ``(mean, variance, scale, offset)``. ``scale`` and ``offset``
        include the adjustment and batch-renorm corrections, so the caller
        must use them when normalizing.
        """
        scale, offset = self._take_adjustment(
            inputs, scale, offset, training)
        # Some of the computations here are not necessary when training is
        # False but not a constant. However, this makes the code simpler.
        keep_dims = (self.virtual_batch_size is not None
                     or len(self.axis) > 1)
        batch_mean, batch_variance = self._moments(
            tf.cast(inputs, self._param_dtype),
            reduction_axes,
            keep_dims=keep_dims)
        moving_mean = self.moving_mean
        moving_variance = self.moving_variance
        mean = smart_cond(
            training,
            lambda: batch_mean,
            lambda: tf.convert_to_tensor(moving_mean))
        variance = smart_cond(
            training,
            lambda: batch_variance,
            lambda: tf.convert_to_tensor(moving_variance))
        new_mean, new_variance = self._take_virtual_batch(mean, variance)
        if self._support_zero_size_input():
            inputs_size = tf.size(inputs)
        else:
            inputs_size = None
        if self.renorm:
            scale, offset, new_mean, new_variance = self._take_renorm(
                new_mean, new_variance,
                scale, offset,
                training, inputs_size,
                _broadcast)

        def _do_update(var: tf.Tensor, value: tf.Tensor):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(
                var, value, self.momentum, inputs_size)

        def mean_update():
            def true_branch():
                return _do_update(self.moving_mean, new_mean)

            def false_branch():
                return self.moving_mean

            return smart_cond(training, true_branch, false_branch)

        def variance_update():
            """Update the moving variance."""
            def true_branch_renorm():
                # We apply epsilon as part of the moving_stddev to mirror
                # the training code path.
                moving_stddev = _do_update(
                    self.moving_stddev,
                    tf.math.sqrt(new_variance + self.epsilon))
                return self._assign_new_value(
                    self.moving_variance,
                    # Apply relu in case floating point rounding causes it
                    # to go negative.
                    K.relu(moving_stddev * moving_stddev - self.epsilon))

            if self.renorm:
                true_branch = true_branch_renorm
            else:
                def true_branch():
                    return _do_update(self.moving_variance, new_variance)

            def false_branch():
                return self.moving_variance

            return smart_cond(training, true_branch, false_branch)

        self.add_update(mean_update)
        self.add_update(variance_update)
        mean = tf.cast(mean, inputs.dtype)
        variance = tf.cast(variance, inputs.dtype)
        return mean, variance, scale, offset

    def training_stats(
            self,
            inputs: tf.Tensor,
            scale: tf.Tensor,
            offset: tf.Tensor,
            reduction_axes: list,
            training,
            _broadcast: Callable) -> Tuple[tf.Tensor, tf.Tensor]:
        """Backward-compatible wrapper: returns only (mean, variance).

        Use ``_training_stats_and_transforms`` to also receive the corrected
        scale/offset (adjustment / batch renorm).
        """
        mean, variance, _, _ = self._training_stats_and_transforms(
            inputs, scale, offset, reduction_axes, training, _broadcast)
        return mean, variance

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer':
                initializers.serialize(self.beta_initializer),
            'gamma_initializer':
                initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer':
                initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer':
                initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer':
                regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer':
                regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint),
            # `fused` decides whether statistics are synchronized at all,
            # so it must survive a save / load round trip.
            'fused': self.fused,
        }
        # Only add TensorFlow-specific parameters if they are set,
        # so as to preserve model compatibility with external Keras.
        if self.renorm:
            config['renorm'] = True
            config['renorm_clipping'] = self.renorm_clipping
            config['renorm_momentum'] = self.renorm_momentum
        if self.virtual_batch_size is not None:
            config['virtual_batch_size'] = self.virtual_batch_size
        # Note: adjustment is not serializable.
        if self.adjustment is not None:
            logging.warning(
                'The `adjustment` function of this`BatchNormalization` '
                'layer cannot be serialized and has been omitted from '
                'the layer config. It will not be included when '
                're-creating the layer from the saved config.')
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
def main():
    """Smoke-test the layer on a CIFAR-10-shaped random batch.

    fused=False is passed so the synchronized (non-fused) path is actually
    exercised; with the default fused=None a 4D input selects the fused
    kernel, which does not synchronize statistics.
    """
    inputs = tf.keras.Input([32, 32, 3])
    outputs = SyncBatchNormalization(fused=False)(inputs)
    model = tf.keras.Model(inputs, outputs)
    model.summary()
    batch = tf.random.normal([32, 32, 32, 3])
    # Training mode computes (all-reduced) batch statistics; inference mode
    # uses the moving statistics.
    print(model(batch, training=True).shape)
    print(model(batch, training=False).shape)
# Script entry point: run the smoke test only when executed directly.
# (The stray `tf.keras.layers.BatchNormalization().get_config()` call that
# followed was removed: its result was discarded.)
if __name__ == '__main__':
    main()
@andreped
Copy link

andreped commented Nov 6, 2022

Interesting implementation, @MokkeMeguru!

Before spending too much time on testing it, do you know whether this implementation is compatible with gradient accumulation (GA)? If you are not aware of the concept, I have made a tool for it here.

A problem with doing GA (currently) is that batch norm (BN) is not compatible with it. This is a major hurdle as a lot of people use pretrained models which often depends on BN. Hence, it would be great to find a proper solution for it.

I have seen this implementation of SyncBatchNormalization in tf-contrib. However, I noticed that you had an interesting argument virtual_batch_size, which I hope could be what I was looking for.

Any ideas?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment