@shubham0204
Created June 13, 2021 01:48
import tensorflow as tf

# Mixer layer consisting of token mixing MLPs and channel mixing MLPs
# input shape -> ( batch_size , num_patches , channels )
# output shape -> ( batch_size , num_patches , channels )
# Note: "channels" is used instead of "embedding_dims"
def mixer( x , token_mixing_mlp_dims , channel_mixing_mlp_dims ):
    # Add token mixing MLPs
    token_mixing_out = token_mixing( x , token_mixing_mlp_dims )
    # Shape of token_mixing_out -> ( batch_size , channels , num_patches )
    token_mixing_out = tf.keras.layers.Permute( dims=[ 2 , 1 ] )( token_mixing_out )
    # Shape after transposition -> ( batch_size , num_patches , channels )
    # Add skip connection
    token_mixing_out = tf.keras.layers.Add()( [ x , token_mixing_out ] )
    # Add channel mixing MLPs
    channel_mixing_out = channel_mixing( token_mixing_out , channel_mixing_mlp_dims )
    # Shape of channel_mixing_out -> ( batch_size , num_patches , channels )
    # Add skip connection
    channel_mixing_out = tf.keras.layers.Add()( [ channel_mixing_out , token_mixing_out ] )
    # Shape of channel_mixing_out -> ( batch_size , num_patches , channels )
    return channel_mixing_out
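
# The snippet above calls token_mixing() and channel_mixing() without defining
# them here. Below is a minimal sketch of what they could look like, following
# the MLP-Mixer design ( LayerNorm -> Dense -> GELU -> Dense ). The hidden
# dims, the "gelu" activation string, and the epsilon value are assumptions,
# not part of the original gist.

def token_mixing( x , token_mixing_mlp_dims ):
    # x shape -> ( batch_size , num_patches , channels )
    num_patches = x.shape[ 1 ]
    x = tf.keras.layers.LayerNormalization( epsilon=1e-6 )( x )
    # Transpose so that the MLP mixes information across patches ( tokens )
    x = tf.keras.layers.Permute( dims=[ 2 , 1 ] )( x )
    # Shape -> ( batch_size , channels , num_patches )
    x = tf.keras.layers.Dense( token_mixing_mlp_dims , activation="gelu" )( x )
    # Project back to num_patches so the caller can transpose and add the skip
    x = tf.keras.layers.Dense( num_patches )( x )
    # Shape -> ( batch_size , channels , num_patches )
    return x

def channel_mixing( x , channel_mixing_mlp_dims ):
    # x shape -> ( batch_size , num_patches , channels )
    channels = x.shape[ 2 ]
    x = tf.keras.layers.LayerNormalization( epsilon=1e-6 )( x )
    # The MLP mixes information across channels , independently for each patch
    x = tf.keras.layers.Dense( channel_mixing_mlp_dims , activation="gelu" )( x )
    # Project back to channels so the shapes match for the skip connection
    x = tf.keras.layers.Dense( channels )( x )
    # Shape -> ( batch_size , num_patches , channels )
    return x

# Example usage with hypothetical dims: 64 patches , 128 channels per patch
inputs = tf.keras.layers.Input( shape=( 64 , 128 ) )
outputs = mixer( inputs , token_mixing_mlp_dims=256 , channel_mixing_mlp_dims=512 )
model = tf.keras.models.Model( inputs , outputs )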