@shubham0204
Created June 13, 2021 01:50
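import tensorflow as tf

# Model hyperparameters
hidden_dims = 128
token_mixing_mlp_dims = 64
channel_mixing_mlp_dims = 128
patch_size = 9
num_classes = 10
num_mixer_layers = 4
# Note: reshape_image_dim is defined here but never used in this snippet
reshape_image_dim = 72
input_image_shape = ( 32 , 32 , 3 )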
inputs = tf.keras.layers.Input( shape=input_image_shape )
# Conv2D with kernel_size == strides extracts non-overlapping patches
# and projects each one to hidden_dims channels
patches = tf.keras.layers.Conv2D( hidden_dims , kernel_size=patch_size , strides=patch_size )( inputs )
# Flatten the patch grid into a token sequence of shape ( num_patches , hidden_dims )
patches_reshape = tf.keras.layers.Reshape( ( patches.shape[ 1 ] * patches.shape[ 2 ] , patches.shape[ 3 ] ) )( patches )
x = patches_reshape
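# Stack the mixer layers ( mixer() is not defined in this snippet and must be defined before this loop runs )
for _ in range( num_mixer_layers ):
    x = mixer( x , token_mixing_mlp_dims , channel_mixing_mlp_dims )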
# Classifier head
x = tf.keras.layers.LayerNormalization( epsilon=1e-6 )( x )
x = tf.keras.layers.GlobalAveragePooling1D()( x )
outputs = tf.keras.layers.Dense( num_classes , activation='softmax' )( x )
model = tf.keras.models.Model( inputs , outputs )
model.summary()
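
The snippet above calls `mixer( x , token_mixing_mlp_dims , channel_mixing_mlp_dims )` without defining it. The block below is a minimal sketch of what such a function could look like, assuming the usual MLP-Mixer layout (a token-mixing MLP followed by a channel-mixing MLP, each with layer normalization and a skip connection). The helper `mlp`, the GELU activation, and the absence of dropout are assumptions made for illustration and are not taken from this gist.

```python
import tensorflow as tf


def mlp(x, hidden_units):
    """Hypothetical helper: two Dense layers applied along the last axis.

    The output width equals the input width so a residual Add is possible.
    """
    out_units = x.shape[-1]
    y = tf.keras.layers.Dense(hidden_units, activation="gelu")(x)
    return tf.keras.layers.Dense(out_units)(y)


def mixer(x, token_mixing_mlp_dims, channel_mixing_mlp_dims):
    """Sketch of one mixer layer for inputs of shape (batch, tokens, channels)."""
    # Token mixing: transpose so the Dense layers act across the token axis.
    y = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)
    y = tf.keras.layers.Permute((2, 1))(y)   # (batch, channels, tokens)
    y = mlp(y, token_mixing_mlp_dims)
    y = tf.keras.layers.Permute((2, 1))(y)   # back to (batch, tokens, channels)
    x = tf.keras.layers.Add()([x, y])        # skip connection

    # Channel mixing: Dense layers act across the channel axis.
    y = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)
    y = mlp(y, channel_mixing_mlp_dims)
    return tf.keras.layers.Add()([x, y])     # skip connection


# Usage: the layer keeps the (tokens, channels) shape, so it can be stacked.
tokens = tf.keras.layers.Input(shape=(9, 128))
mixed = mixer(tokens, 64, 128)
print(mixed.shape)
```

Because each layer returns a tensor with the same token and channel dimensions it received, the `for` loop in the snippet can apply it `num_mixer_layers` times without any reshaping in between.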
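
The patch-extracting `Conv2D` in the snippet uses Keras' default `padding='valid'`, so the number of tokens depends on how evenly `patch_size` divides the image side. The sketch below works through that arithmetic in plain Python for the two sizes that appear in the hyperparameters (32 and 72). The function names `conv_output_size` and `patch_grid` are hypothetical helpers, not part of the gist or of Keras.

```python
def conv_output_size(size, kernel, stride):
    """Output length along one axis of an unpadded ('valid') convolution."""
    return (size - kernel) // stride + 1


def patch_grid(image_side, patch_size):
    """Return (patches per side, total patches, pixels per side left uncovered)."""
    per_side = conv_output_size(image_side, patch_size, patch_size)
    uncovered = image_side - per_side * patch_size
    return per_side, per_side * per_side, uncovered


print(patch_grid(32, 9))  # (3, 9, 5)
print(patch_grid(72, 9))  # (8, 64, 0)
```

With a 32 x 32 input the model sees 9 tokens and ignores the last 5 rows and columns of pixels, whereas a 72 x 72 input would tile exactly into 64 patches.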