Skip to content

Instantly share code, notes, and snippets.

@fk128
Created May 21, 2021 07:44
Show Gist options
  • Save fk128/58bf1641f07c15a8928effb62ce7e7b5 to your computer and use it in GitHub Desktop.
Save fk128/58bf1641f07c15a8928effb62ce7e7b5 to your computer and use it in GitHub Desktop.
Mean-only and std-only batch normalization layers (TensorFlow 2)
from tensorflow.keras import layers
import tensorflow as tf
class MeanOnlyBN(layers.BatchNormalization):
    """Batch normalization that only subtracts the batch mean.

    Variance normalization is disabled by reporting a variance of ones to the
    base layer. The learned affine transform is disabled (``scale=False``,
    ``center=False``). As a result, the layer only removes the (batch or
    moving) mean.

    NOTE(review): the Keras docs give the normalization as
    ``(x - mean) / sqrt(variance + epsilon)``. With the variance fixed at one,
    outputs are therefore also scaled by ``1 / sqrt(1 + epsilon)``. Pass a
    small ``epsilon`` if exactness matters.

    Relies on the private ``_calculate_mean_and_var`` hook of the TF2 /
    Keras 2 ``BatchNormalization`` implementation. Presumably it is not
    available in Keras 3 — confirm against the installed version.
    """

    def __init__(self, **kwargs):
        """Create the layer; any ``BatchNormalization`` keyword is accepted.

        ``scale``, ``center`` and ``fused`` are always forced to ``False``,
        overriding any value the caller passed for them. ``fused`` must be off
        because the fused kernel computes its own statistics and never calls
        ``_calculate_mean_and_var``.
        """
        kwargs['scale'] = False
        kwargs['center'] = False
        kwargs['fused'] = False
        super().__init__(**kwargs)

    def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims,
                                **kwargs):
        """Return ``(batch_mean, ones)`` instead of ``(batch_mean, batch_var)``.

        The method delegates to the base implementation, so any extra keyword
        arguments newer Keras versions pass (e.g. ``mask``) are honoured.

        The variance is then replaced with ``tf.ones_like`` rather than a bare
        Python ``1``. This keeps the same shape and float dtype as the real
        statistics and as ``moving_variance``. A scalar int can break the base
        layer when ``axis`` is not the last dimension (reshape to the
        per-channel broadcast shape) or when ``training`` is a tensor
        (dtype-mismatched ``cond`` branches).
        """
        mean, variance = super()._calculate_mean_and_var(
            inputs, reduction_axes, keep_dims, **kwargs)
        return mean, tf.ones_like(variance)
class StdOnlyBN(layers.BatchNormalization):
    """Batch normalization that only divides by the batch standard deviation.

    Mean subtraction is disabled by reporting a mean of zeros to the base
    layer. The learned affine transform is disabled (``scale=False``,
    ``center=False``). As a result, the layer computes
    ``x / sqrt(variance + epsilon)``.

    The variance is still the population variance around the batch mean,
    even though that mean is not subtracted from the output.

    Relies on the private ``_calculate_mean_and_var`` hook of the TF2 /
    Keras 2 ``BatchNormalization`` implementation. Presumably it is not
    available in Keras 3 — confirm against the installed version.
    """

    def __init__(self, **kwargs):
        """Create the layer; any ``BatchNormalization`` keyword is accepted.

        ``scale``, ``center`` and ``fused`` are always forced to ``False``,
        overriding any value the caller passed for them. ``fused`` must be off
        because the fused kernel computes its own statistics and never calls
        ``_calculate_mean_and_var``.
        """
        kwargs['scale'] = False
        kwargs['center'] = False
        kwargs['fused'] = False
        super().__init__(**kwargs)

    def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims,
                                **kwargs):
        """Return ``(zeros, batch_var)`` instead of ``(batch_mean, batch_var)``.

        The method delegates to the base implementation (``tf.nn.moments``
        in TF2). Its variance matches ``tf.math.reduce_variance``, and any
        extra keyword arguments newer Keras versions pass (e.g. ``mask``) are
        honoured.

        The mean is then replaced with ``tf.zeros_like`` rather than a bare
        Python ``0``. This keeps the same shape and float dtype as the real
        statistics and as ``moving_mean``. A scalar int can break the base
        layer when ``axis`` is not the last dimension or when ``training`` is
        a tensor.
        """
        mean, variance = super()._calculate_mean_and_var(
            inputs, reduction_axes, keep_dims, **kwargs)
        return tf.zeros_like(mean), variance
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment