@keuv-grvl
Created October 20, 2021 08:09
Wrap TF op in Keras layer with lambda and closure
import tensorflow as tf
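# Factory: ReduceMean(axis) returns a Lambda layer whose closure captures `axis`.
ReduceMean = lambda axis=1: tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=axis))
i = tf.keras.layers.Input(shape=(5, 15))  # shape: [BATCHSIZE, 5, 15]
# Build the layer first, then call it on the tensor; ReduceMean(i) would bind `i` to `axis`.
m = ReduceMean(axis=1)(i)  # shape: [BATCHSIZE, 15]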
# TODO would be nice to have a wrapping function for any TF op with arbitrary
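One possible way to approach the TODO above, as a rough sketch rather than a definitive implementation: a small factory that takes any TF op plus keyword arguments and returns a Lambda layer whose closure captures both. The name `wrap_op` is hypothetical (it is not part of TensorFlow or Keras), and the sketch assumes the op takes the input tensor as its first positional argument.

```python
import tensorflow as tf


def wrap_op(op, **kwargs):
    """Hypothetical helper: wrap a TF op in a Keras Lambda layer.

    Assumes `op` accepts the input tensor as its first positional argument;
    `kwargs` are captured by the closure and forwarded on every call.
    """
    return tf.keras.layers.Lambda(lambda x: op(x, **kwargs))


inputs = tf.keras.layers.Input(shape=(5, 15))    # shape: [BATCHSIZE, 5, 15]
mean = wrap_op(tf.reduce_mean, axis=1)(inputs)   # shape: [BATCHSIZE, 15]
peak = wrap_op(tf.reduce_max, axis=2)(inputs)    # shape: [BATCHSIZE, 5]
```

As with any Lambda layer, the wrapped function is stored as Python code, so a model built this way may not be portable when saved and reloaded elsewhere; subclassing `tf.keras.layers.Layer` is the more robust route if that matters.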