Calculate the total number of parameters in a Transformer Decoder
""" | |
Calculate the total number of parameters in a Transformer Decoder. | |
Usage: | |
# OPT 125M | |
python3 transformer_decoder_parameters.py --num-decoder-layers 12 --hidden-dim 768 --num-heads 12 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
# OPT 350M | |
python3 transformer_decoder_parameters.py --num-decoder-layers 24 --hidden-dim 1024 --num-heads 16 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
# OPT 1.3B | |
python3 transformer_decoder_parameters.py --num-decoder-layers 24 --hidden-dim 2048 --num-heads 32 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
# OPT 2.7B | |
python3 transformer_decoder_parameters.py --num-decoder-layers 32 --hidden-dim 2560 --num-heads 32 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
# OPT 6.7B | |
python3 transformer_decoder_parameters.py --num-decoder-layers 32 --hidden-dim 4096 --num-heads 32 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
# OPT 13B | |
python3 transformer_decoder_parameters.py --num-decoder-layers 40 --hidden-dim 5120 --num-heads 40 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
# OPT 30B | |
python3 transformer_decoder_parameters.py --num-decoder-layers 48 --hidden-dim 7168 --num-heads 56 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
# OPT 66B | |
python3 transformer_decoder_parameters.py --num-decoder-layers 64 --hidden-dim 9216 --num-heads 72 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
# OPT 175B | |
python3 transformer_decoder_parameters.py --num-decoder-layers 96 --hidden-dim 12288 --num-heads 96 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True | |
q_size: 1572992 | |
k_size: 1572992 | |
v_size: 1572992 | |
out_size: 150994944 | |
attn_size: 604016640 | |
decoder layer size: 1812086784 | |
decoder size (without embeddings): 173B | |
embedding_layer_size: 642M | |
decoder size: 174B | |
""" | |
import fire


def pretty_print_large_num(n):
    if n < 1e9:
        return str(n // 1000000) + "M"
    elif n < 1e10:
        return f"{round(n / 1000000000, 1):.1f}" + "B"
    else:
        return str(n // 1000000000) + "B"

# FFN Module
class FFNModuleParamCount():
    @classmethod
    def count(cls, model_dim, ffn_dim):
        return 2 * model_dim * ffn_dim + model_dim + ffn_dim  # the latter 2 are bias terms

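
# Quick sanity check of the FFN formula, assuming the OPT 175B dims from the usage notes
# above and the default ffn_dim = 4 * hidden_dim used below:
#   2 * 12288 * 49152 + 12288 + 49152 = 1,208,020,992
# i.e. two weight matrices of shape (12288, 49152) / (49152, 12288) plus both bias vectors.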
# Self-attention Module
class SelfAttentionParamCount():
    @classmethod
    def count(cls, hidden_dim, head_size, num_heads):
        q_dim = hidden_dim
        k_dim = hidden_dim
        v_dim = hidden_dim
        # Per-head Q/K/V projections of shape (hidden_dim, head_size), each with a bias.
        q_size = q_dim * head_size + head_size
        k_size = k_dim * head_size + head_size
        v_size = v_dim * head_size + head_size
        print(f"q_size: {q_size}")
        print(f"k_size: {k_size}")
        print(f"v_size: {v_size}")
        # Output projection of shape (num_heads * head_size, hidden_dim); its bias is not counted here.
        out_size = head_size * num_heads * hidden_dim
        attn_size = num_heads * (q_size + k_size + v_size) + out_size
        print(f"out_size: {out_size}")
        print(f"attn_size: {attn_size}")
        return attn_size

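
# For the OPT 175B setting (hidden_dim = 12288, head_size = 128, num_heads = 96) this
# reproduces the numbers quoted in the module docstring:
#   q_size = k_size = v_size = 12288 * 128 + 128 = 1,572,992
#   out_size = 128 * 96 * 12288 = 150,994,944
#   attn_size = 96 * 3 * 1,572,992 + 150,994,944 = 604,016,640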
# Decoder layer
class DecoderLayerParamCount():
    def __init__(self, hidden_dim=896, num_heads=16, head_size=None, ffn_dim=None) -> None:
        self.hidden_dim = hidden_dim
        self.ffn_dim = 4 * hidden_dim if ffn_dim is None else ffn_dim
        self.num_heads = num_heads
        if head_size is None:
            assert hidden_dim % num_heads == 0
        self.head_size = hidden_dim // num_heads if head_size is None else head_size

    @property
    def count(self):
        mh_attn_size = SelfAttentionParamCount.count(self.hidden_dim, self.head_size, self.num_heads)
        ffn_size = FFNModuleParamCount.count(self.hidden_dim, self.ffn_dim)
        layer_norm_size = self.hidden_dim * 2  # scale + shift vectors for one LayerNorm
        # Two LayerNorms per decoder layer (attention and FFN sub-layers).
        return mh_attn_size + ffn_size + 2 * layer_norm_size

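
# For OPT 175B this gives 604,016,640 (attention) + 1,208,020,992 (FFN) + 2 * 24,576
# (LayerNorms) = 1,812,086,784 parameters per layer, matching "decoder layer size" in
# the expected output above.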
# Decoder
class DecoderParamCount():
    def __init__(self, num_decoder_layers=12, hidden_dim=896, num_heads=16, head_size=None, ffn_dim=None,
                 vocab_size=None, emb_dim=None, sequence_len=2048, use_learned_pos_emb=False) -> None:
        self.decoder_layer_param_count = DecoderLayerParamCount(
            hidden_dim=hidden_dim,
            ffn_dim=ffn_dim,
            num_heads=num_heads,
            head_size=head_size
        )
        self.num_decoder_layers = num_decoder_layers
        self.vocab_size = vocab_size
        self.emb_dim = hidden_dim if emb_dim is None else emb_dim
        self.sequence_len = sequence_len
        self.use_learned_pos_emb = use_learned_pos_emb

    def count(self, include_embedding_param_count=False):
        decoder_layer_size = self.decoder_layer_param_count.count
        print(f"decoder layer size: {decoder_layer_size}")
        count = self.num_decoder_layers * decoder_layer_size
        print(f"decoder size (without embeddings): {pretty_print_large_num(count)}")
        if include_embedding_param_count:
            # Token embedding table (assumed tied with the output projection, so counted once).
            embedding_layer_size = self.emb_dim * self.vocab_size
            if self.use_learned_pos_emb:
                # Learned positional embeddings, one vector per position.
                pos_emb_size = self.sequence_len * self.emb_dim
                embedding_layer_size += pos_emb_size
            print(f"embedding_layer_size: {pretty_print_large_num(embedding_layer_size)}")
            count += embedding_layer_size
        print(f"decoder size: {pretty_print_large_num(count)}")
        return count

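
# For OPT 175B: 96 * 1,812,086,784 = 173,960,331,264 ("173B" without embeddings);
# embeddings add 50272 * 12288 + 2048 * 12288 = 642,908,160 ("642M"), for a total of
# 174,603,239,424 ("174B"), matching the expected output in the module docstring.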
def main(num_decoder_layers: int, hidden_dim: int, num_heads: int, head_size: int = None, ffn_dim: int = None,
         vocab_size: int = None, emb_dim: int = None, sequence_len: int = 2048, use_learned_pos_emb: bool = False):
    decoder_param_count = DecoderParamCount(
        num_decoder_layers=num_decoder_layers,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        head_size=head_size,
        ffn_dim=ffn_dim,
        vocab_size=vocab_size,
        emb_dim=emb_dim,
        sequence_len=sequence_len,
        use_learned_pos_emb=use_learned_pos_emb
    )
    decoder_param_count.count(include_embedding_param_count=True)


if __name__ == "__main__":
    fire.Fire(main)
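
# Programmatic usage mirrors the CLI; e.g. the OPT 175B configuration from the docstring:
#   DecoderParamCount(num_decoder_layers=96, hidden_dim=12288, num_heads=96,
#                     vocab_size=50272, sequence_len=2048,
#                     use_learned_pos_emb=True).count(include_embedding_param_count=True)
# returns 174,603,239,424 and prints "decoder size: 174B".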