Simple batch normalization prediction fix for Keras 1.x weights
import h5py
import keras.backend as K
from keras.layers import Lambda

# Load the weights. (Assumption: they sit in an HDF5 file readable with h5py;
# adjust the path and the key names to your own weights file / layer group.)
f = h5py.File('old_model_weights.h5', 'r')

# If there is more than one BN layer, you probably want a proper custom layer
# instead of repeating this for each one.
moving_mean = f['batchnormalization_1_running_mean'][:]
moving_std = f['batchnormalization_1_running_std'][:]
beta = f['batchnormalization_1_beta'][:]
gamma = f['batchnormalization_1_gamma'][:]

# BE CAREFUL! Keras 1 stores the running *std*,
# while the backend function still takes the *variance*.
moving_variance = moving_std ** 2
# Match the epsilon used by the original BatchNormalization layer
epsilon = 1e-6

def BN_test_time(x):
    # Inference-time BN using the stored (moving) statistics
    return K.batch_normalization(x,
                                 moving_mean,
                                 moving_variance,
                                 beta,
                                 gamma,
                                 epsilon=epsilon)

def BN_output_shape(input_shape):
    # Batch normalization does not change the tensor shape
    return input_shape
# ... and in your model:
model.add(Lambda(BN_test_time,
                 output_shape=BN_output_shape))
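
Below is a minimal end-to-end sketch of how the Lambda layer can be checked against the textbook inference formula gamma * (x - mean) / sqrt(var + epsilon) + beta. It is not part of the original gist: the 3-feature input, the statistics, and the tiny Sequential model are made-up illustrations standing in for weights read from a file.

import numpy as np
import keras.backend as K
from keras.models import Sequential
from keras.layers import Lambda

# Made-up moving statistics for a 3-feature input (illustrative only)
moving_mean = np.array([0.5, -1.0, 2.0], dtype='float32')
moving_std = np.array([1.5, 0.3, 2.2], dtype='float32')
moving_variance = moving_std ** 2
beta = np.zeros(3, dtype='float32')
gamma = np.ones(3, dtype='float32')
epsilon = 1e-6

def BN_test_time(x):
    return K.batch_normalization(x, moving_mean, moving_variance,
                                 beta, gamma, epsilon=epsilon)

def BN_output_shape(input_shape):
    return input_shape

model = Sequential()
model.add(Lambda(BN_test_time,
                 output_shape=BN_output_shape,
                 input_shape=(3,)))

x = np.random.rand(4, 3).astype('float32')
expected = gamma * (x - moving_mean) / np.sqrt(moving_variance + epsilon) + beta
print(np.allclose(model.predict(x), expected, atol=1e-4))  # should print True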