Transformer Parameter Count
import argparse
import math

# Helper function to pretty-print parameter counts
def convert_params(params):
    if params == 0:
        return "0"
    size_name = ("", "K", "M", "B", "T", "P", "E", "Z", "Y")
    i = int(math.floor(math.log(params, 1000)))
    p = math.pow(1000, i)
    s = round(params / p, 2)
    return "%s %s" % (s, size_name[i])
def config_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab-size", "-v",
                        type=int,
                        default=51200,
                        help='Size of the vocab')
    parser.add_argument("--hidden-size", "-hs",
                        type=int,
                        default=6144,
                        help='Dimension of the model\'s hidden size')
    parser.add_argument("--num-attention-heads", "-a",
                        type=int,
                        default=64,
                        help='Number of attention heads used in model')
    parser.add_argument("--sequence-length", "-s",
                        type=int,
                        default=2048,
                        help='Sequence length used for training')
    parser.add_argument("--num-layers", "-l",
                        type=int,
                        default=44,
                        help='Number of transformer layers used in model')
    parser.add_argument("--moe",
                        action="store_true",
                        help='Whether our model is MoE')
    parser.add_argument("--num-experts", "-e",
                        type=int,
                        default=128,
                        help='Number of experts for MoE')
    parser.add_argument("--expert-interval", "-ei",
                        type=int,
                        default=2,
                        help='Expert interval for MoE')
    parser.add_argument("--topk", "-t",
                        type=int,
                        default=1,
                        help='Top k routing for MoE')
    return parser
# Calculates the parameter count of a model given its hyperparameters
def calc_params(args):
    # Assumes that the embedding and unembedding are tied
    embedding_params = args.hidden_size * args.vocab_size
    position_embedding_params = args.hidden_size * args.sequence_length
    # Each of the Q, K, V, and O matrices is (h x h)
    attention_params = 4 * args.num_layers * args.hidden_size * args.hidden_size
    # (4*2)lh from the layernorm weights and biases for each of the QKV and mlp_in layernorms, 2h for the final layernorm
    layernorm_params = 8 * args.num_layers * args.hidden_size + 2 * args.hidden_size
    #ffn_params = 12 * args.num_layers * args.hidden_size * args.hidden_size
    if args.moe:
        # The proportion of layers that are MoE (e.g. every 2nd layer for GShard); see the worked check below the function
        num_expert_layers = args.num_layers / args.expert_interval
        # The number of FFN params for each MoE layer
        ffn_expert_params = 8 * num_expert_layers * args.num_experts * args.hidden_size * args.hidden_size
        # The number of FFN params for every dense layer
        ffn_dense_params = 8 * (args.num_layers - num_expert_layers) * args.hidden_size * args.hidden_size
        ffn_params = ffn_expert_params + ffn_dense_params
        # The number of gating layer params, assuming it's implemented as a simple linear layer
        gating_params = num_expert_layers * args.hidden_size * args.num_experts
    else:
        # Two (h x 4h) FFN matrices
        ffn_params = 8 * args.num_layers * args.hidden_size * args.hidden_size
    total_params = embedding_params + attention_params + ffn_params + position_embedding_params + layernorm_params
    if args.moe:
        total_params += gating_params
    print(f'Calculating number of parameters with training configuration: {vars(args)}\n')
    print(f'Embedding parameters: {convert_params(embedding_params)}')
    print(f'Attention parameters: {convert_params(attention_params)}')
    print(f'FFN parameters: {convert_params(ffn_params)}')
    if args.moe:
        print(f'Gating parameters: {convert_params(gating_params)}')
    print(f'Total Params in the Model: {convert_params(total_params)}')
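# Worked check of the MoE branch above (not part of the original gist), using the
# Fairseq-MoE 15B example from __main__ (-l 12 -hs 768 --moe -e 512) together with
# the script's defaults (vocab 51200, seq 2048, expert interval 2):
#   num_expert_layers = 12 / 2 = 6
#   ffn_expert_params = 8 * 6 * 512 * 768**2 = 14,495,514,624
#   ffn_dense_params  = 8 * 6 * 768**2       = 28,311,552
#   gating_params     = 6 * 768 * 512        = 2,359,296
# Together with the embedding, attention, and layernorm terms this gives ~14.6 B total.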
if __name__ == "__main__":
    print('\nExample with Fairseq-MoE 15B: python calc_transformer_params.py -l 12 -hs 768 --moe -e 512')
    print('Example with GPT-3 175B: python calc_transformer_params.py -l 96 -hs 12288')
    args = config_parser().parse_args()
    calc_params(args)
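As a quick sanity check (not part of the gist itself), the dense-path formulas can be evaluated by hand for the GPT-3 175B example above, keeping the script's default vocab size (51200) and sequence length (2048):

l, h, v, s = 96, 12288, 51200, 2048
embedding = h * v                # 629,145,600
position  = h * s                #  25,165,824
attention = 4 * l * h * h        # 57,982,058,496
layernorm = 8 * l * h + 2 * h    #   9,461,760
ffn       = 8 * l * h * h        # 115,964,116,992
total = embedding + position + attention + ffn + layernorm
print(total)                     # 174,609,948,672, printed as "174.61 B"

which lines up with the commonly quoted ~175B parameter count for GPT-3.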