Skip to content

Instantly share code, notes, and snippets.

@machelreid
Created December 8, 2020 06:33
Show Gist options
  • Save machelreid/f5fb09da8b3bb55fe6a4763472558791 to your computer and use it in GitHub Desktop.
Save machelreid/f5fb09da8b3bb55fe6a4763472558791 to your computer and use it in GitHub Desktop.
Compute the trainable-parameter count of the vanilla Transformer (Vaswani et al., 2017)
import argparse
def get_enc_params(embed_dim, ffn_dim):
    """Return the number of trainable parameters in one encoder layer.

    The count covers the self-attention block (query, key, value and output
    projections, each d x d plus a bias of size d) and the two-layer
    feed-forward network (d -> f -> d, with biases). LayerNorm parameters
    are not included.

    Args:
        embed_dim: model (embedding) dimension ``d`` of the layer.
        ffn_dim: hidden dimension ``f`` of the feed-forward sublayer.

    Returns:
        ``4*d*d + 2*d*f + 5*d + f``.
    """
    d, f = embed_dim, ffn_dim
    attention = 4 * d * d + 4 * d  # Q, K, V, O weights + biases
    feed_forward = 2 * d * f + f + d  # fc1 (d->f) and fc2 (f->d) weights + biases
    return attention + feed_forward
def get_dec_params(embed_dim, ffn_dim, encoder_embed_dim=None):
    """Return the number of trainable parameters in one decoder layer.

    Uses the same accounting as the encoder-layer count: every linear
    projection's weights and biases are counted, LayerNorm parameters are
    not.

    Args:
        embed_dim: decoder model dimension ``d``.
        ffn_dim: hidden dimension ``f`` of the feed-forward sublayer.
        encoder_embed_dim: encoder output dimension ``e``, consumed by the
            key and value projections of encoder-decoder attention.
            Defaults to ``embed_dim`` when omitted.

    Returns:
        ``(4*d*d + 4*d) + (2*d*d + 2*e*d + 4*d) + (2*d*f + f + d)``.
    """
    # Previously a None default reached `None * None` and raised TypeError.
    if encoder_embed_dim is None:
        encoder_embed_dim = embed_dim
    d, f, e = embed_dim, ffn_dim, encoder_embed_dim

    # Masked self-attention: Q, K, V, O are all d -> d, each with a size-d bias.
    self_attention = 4 * d * d + 4 * d
    # Encoder-decoder attention: Q and O are d -> d, but K and V project the
    # encoder output e -> d (not e -> e). All four biases have size d.
    cross_attention = 2 * d * d + 2 * e * d + 4 * d
    # Position-wise FFN: fc1 d -> f (+ bias f), fc2 f -> d (+ bias d).
    feed_forward = 2 * d * f + f + d
    return self_attention + cross_attention + feed_forward
if __name__ == "__main__":
    # Command-line entry point: prints the total parameter count of an
    # encoder-decoder Transformer, using "_" as the thousands separator.
    parser = argparse.ArgumentParser(
        description="Count the trainable parameters of a vanilla Transformer "
        "(Vaswani et al., 2017) encoder-decoder."
    )
    # Every size is mandatory: a missing value would otherwise be None and
    # crash later with an opaque TypeError during the arithmetic.
    parser.add_argument("--enc-embed-dim", type=int, required=True,
                        help="encoder model/embedding dimension")
    parser.add_argument("--dec-embed-dim", type=int, required=True,
                        help="decoder model/embedding dimension")
    parser.add_argument("--enc-ffn-dim", type=int, required=True,
                        help="encoder feed-forward hidden dimension")
    parser.add_argument("--dec-ffn-dim", type=int, required=True,
                        help="decoder feed-forward hidden dimension")
    parser.add_argument("--vocab-size", type=int, required=True,
                        help="vocabulary size")
    parser.add_argument("--dec-layers", type=int, required=True,
                        help="number of decoder layers")
    parser.add_argument("--enc-layers", type=int, required=True,
                        help="number of encoder layers")
    args = parser.parse_args()

    encoder_total = args.enc_layers * get_enc_params(
        embed_dim=args.enc_embed_dim, ffn_dim=args.enc_ffn_dim
    )
    decoder_total = args.dec_layers * get_dec_params(
        args.dec_embed_dim, args.dec_ffn_dim, args.enc_embed_dim
    )
    # Only one vocab_size x enc_embed_dim matrix is counted, i.e. source
    # embedding, target embedding and output projection are presumably
    # treated as a single shared matrix. NOTE(review): such sharing is only
    # possible when enc_embed_dim == dec_embed_dim — confirm for other configs.
    embedding_total = args.vocab_size * args.enc_embed_dim

    print("{:_}".format(encoder_total + decoder_total + embedding_total))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment