Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created March 3, 2021 11:56
Show Gist options
  • Save ntakouris/3f2e785a61f974bff80c109f59eb238b to your computer and use it in GitHub Desktop.
Save ntakouris/3f2e785a61f974bff80c109f59eb238b to your computer and use it in GitHub Desktop.
class VFELayer(keras.Model):
    """Per-point MLP encoder with a max-pooled global feature appended to each row.

    ``call`` expects a rank-3 input ``(batch, n, features)``; the tile
    multiples used there have three entries. The pipeline is:

    * dropout on the input, then two BatchNorm -> Dense(swish) stages applied
      row-wise (the "local" features, ``l2`` wide);
    * a global branch: BatchNorm -> Dense(swish, ``l3`` units) -> max over
      axis 1, giving one ``l3`` vector per batch element;
    * that vector is repeated ``n`` times and concatenated onto the local
      features, so the output is ``(batch, n, l2 + l3)``.

    NOTE(review): no mask is used. If callers pad variable-length point sets,
    the padded rows take part in the max-pool; confirm this is intended.
    NOTE(review): the default ``name`` is a fixed string. Two instances in one
    functional model would likely collide unless given distinct names.

    Args:
        name: Keras name for this model.
        l1: units of the first row-wise Dense layer.
        l2: units of the second row-wise Dense layer (local part of the output).
        l3: units of the global Dense layer (global part of the output).
        dropout: rate of the Dropout applied to the raw input.
        **kwargs: forwarded unchanged to ``keras.Model.__init__``.
    """

    def __init__(self, name='VFELayer', l1=32, l2=32, l3=48, dropout=0, **kwargs):
        super().__init__(name=name, **kwargs)

        def swish_dense(units):
            # All three projections share the same initializer and activation.
            return keras.layers.Dense(
                units,
                kernel_initializer='glorot_uniform',
                activation='swish',
            )

        # Attribute names and creation order are kept exactly as they were:
        # Keras tracks/auto-names sublayers from them, and saved checkpoints
        # may depend on both.
        self.bn_a = keras.layers.BatchNormalization()
        self.local_mlp_a = swish_dense(l1)
        self.bn_b = keras.layers.BatchNormalization()
        self.local_mlp_b = swish_dense(l2)
        self.bn_g = keras.layers.BatchNormalization()
        self.global_mlp = swish_dense(l3)
        self.pool = keras.layers.GlobalMaxPooling1D()
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, **kwargs):
        """Encode ``x`` row-wise and append the pooled global feature to every row.

        ``**kwargs`` (presumably just ``training``) is forwarded to every
        sublayer except the global Dense projection.
        """
        # Local branch, applied in this fixed order.
        local_stages = (
            self.dropout,
            self.bn_a,
            self.local_mlp_a,
            self.bn_b,
            self.local_mlp_b,
        )
        for stage in local_stages:
            x = stage(x, **kwargs)

        # Global branch: normalise, project, then max over axis 1 -> (batch, l3).
        normalised = self.bn_g(x, **kwargs)
        pooled = self.pool(self.global_mlp(normalised), **kwargs)

        # Repeat the pooled vector once per row of x (row count may be dynamic).
        row_count = K.shape(x)[1]
        broadcast = tf.tile(tf.expand_dims(pooled, 1), [1, row_count, 1])
        return tf.concat([x, broadcast], -1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment