Calculate the total number of parameters in a Transformer Decoder
"""
Calculate the total number of parameters in a Transformer Decoder.
Usage:
# OPT 125M
python3 transformer_decoder_parameters.py --num-decoder-layers 12 --hidden-dim 768 --num-heads 12 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
# OPT 350M
python3 transformer_decoder_parameters.py --num-decoder-layers 24 --hidden-dim 1024 --num-heads 16 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
# OPT 1.3B
python3 transformer_decoder_parameters.py --num-decoder-layers 24 --hidden-dim 2048 --num-heads 32 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
# OPT 2.7B
python3 transformer_decoder_parameters.py --num-decoder-layers 32 --hidden-dim 2560 --num-heads 32 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
# OPT 6.7B
python3 transformer_decoder_parameters.py --num-decoder-layers 32 --hidden-dim 4096 --num-heads 32 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
# OPT 13B
python3 transformer_decoder_parameters.py --num-decoder-layers 40 --hidden-dim 5120 --num-heads 40 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
# OPT 30B
python3 transformer_decoder_parameters.py --num-decoder-layers 48 --hidden-dim 7168 --num-heads 56 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
# OPT 66B
python3 transformer_decoder_parameters.py --num-decoder-layers 64 --hidden-dim 9216 --num-heads 72 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
# OPT 175B
python3 transformer_decoder_parameters.py --num-decoder-layers 96 --hidden-dim 12288 --num-heads 96 --vocab-size 50272 --sequence-len 2048 --use-learned-pos-emb True
Example output (OPT 175B):
q_size: 1572992
k_size: 1572992
v_size: 1572992
out_size: 150994944
attn_size: 604016640
decoder layer size: 1812086784
decoder size (without embeddings): 173B
embedding_layer_size: 642M
decoder size: 174B
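Note: with ffn_dim = 4 * hidden_dim and head_size = hidden_dim / num_heads, the per-layer
count computed by this script reduces to 12 * hidden_dim^2 + 12 * hidden_dim (the attention
output-projection bias is not included); for hidden_dim = 12288 this gives the
1812086784 per-layer figure shown above.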
"""
import fire


def pretty_print_large_num(n):
    if n < 1e9:
        return str(n // 1000000) + "M"
    elif n < 1e10:
        return f"{n / 1000000000:.1f}" + "B"
    else:
        return str(n // 1000000000) + "B"
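# Examples (sanity check for the formatter above):
#   pretty_print_large_num(642908160)    -> "642M"
#   pretty_print_large_num(1500000000)   -> "1.5B"
#   pretty_print_large_num(174603239424) -> "174B"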


# FFN Module
class FFNModuleParamCount:
    @classmethod
    def count(cls, model_dim, ffn_dim):
        return 2 * model_dim * ffn_dim + model_dim + ffn_dim  # the latter 2 are bias terms
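# Note on FFNModuleParamCount: the count covers two linear layers, model_dim -> ffn_dim and
# ffn_dim -> model_dim (2 * model_dim * ffn_dim weights), plus one bias vector per layer
# (ffn_dim + model_dim).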


# Self-attention Module
class SelfAttentionParamCount:
    @classmethod
    def count(cls, hidden_dim, head_size, num_heads):
        q_dim = hidden_dim
        k_dim = hidden_dim
        v_dim = hidden_dim
        q_size = q_dim * head_size + head_size
        k_size = k_dim * head_size + head_size
        v_size = v_dim * head_size + head_size
        print(f"q_size: {q_size}")
        print(f"k_size: {k_size}")
        print(f"v_size: {v_size}")
        out_size = head_size * num_heads * hidden_dim
        attn_size = num_heads * (q_size + k_size + v_size) + out_size
        print(f"out_size: {out_size}")
        print(f"attn_size: {attn_size}")
        return attn_size
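# Note on SelfAttentionParamCount: the Q/K/V projections are counted per head including their
# biases, while the output projection contributes only its weight matrix
# (head_size * num_heads * hidden_dim); this matches the example numbers in the docstring.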


# Decoder layer
class DecoderLayerParamCount:
    def __init__(self, hidden_dim=896, num_heads=16, head_size=None, ffn_dim=None) -> None:
        self.hidden_dim = hidden_dim
        self.ffn_dim = 4 * hidden_dim if ffn_dim is None else ffn_dim
        self.num_heads = num_heads
        if head_size is None:
            assert hidden_dim % num_heads == 0
            self.head_size = hidden_dim // num_heads
        else:
            self.head_size = head_size

    @property
    def count(self):
        mh_attn_size = SelfAttentionParamCount.count(self.hidden_dim, self.head_size, self.num_heads)
        ffn_size = FFNModuleParamCount.count(self.hidden_dim, self.ffn_dim)
        layer_norm_size = self.hidden_dim * 2
        return mh_attn_size + ffn_size + 2 * layer_norm_size
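# Sanity check (OPT 175B settings): DecoderLayerParamCount(hidden_dim=12288, num_heads=96).count
# evaluates to 12 * 12288**2 + 12 * 12288 = 1812086784, the per-layer size in the docstring output.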


# Decoder
class DecoderParamCount:
    def __init__(self, num_decoder_layers=12, hidden_dim=896, num_heads=16, head_size=None,
                 ffn_dim=None, vocab_size=None, emb_dim=None, sequence_len=2048,
                 use_learned_pos_emb=False) -> None:
        self.decoder_layer_param_count = DecoderLayerParamCount(
            hidden_dim=hidden_dim,
            ffn_dim=ffn_dim,
            num_heads=num_heads,
            head_size=head_size
        )
        self.num_decoder_layers = num_decoder_layers
        self.vocab_size = vocab_size
        self.emb_dim = hidden_dim if emb_dim is None else emb_dim
        self.sequence_len = sequence_len
        self.use_learned_pos_emb = use_learned_pos_emb

    def count(self, include_embedding_param_count=False):
        decoder_layer_size = self.decoder_layer_param_count.count
        print(f"decoder layer size: {decoder_layer_size}")
        count = self.num_decoder_layers * decoder_layer_size
        print(f"decoder size (without embeddings): {pretty_print_large_num(count)}")
        if include_embedding_param_count:
            embedding_layer_size = self.emb_dim * self.vocab_size
            if self.use_learned_pos_emb:
                pos_emb_size = self.sequence_len * self.emb_dim
                embedding_layer_size += pos_emb_size
            print(f"embedding_layer_size: {pretty_print_large_num(embedding_layer_size)}")
            count += embedding_layer_size
        print(f"decoder size: {pretty_print_large_num(count)}")
        return count
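# Note on DecoderParamCount: the token embedding (vocab_size * emb_dim) is counted once, i.e. the
# output (LM-head) projection is assumed to share weights with the input embedding; learned
# positional embeddings add sequence_len * emb_dim parameters when --use-learned-pos-emb is set.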


def main(num_decoder_layers: int, hidden_dim: int, num_heads: int, head_size: int = None,
         ffn_dim: int = None, vocab_size: int = None, emb_dim: int = None,
         sequence_len=2048, use_learned_pos_emb: bool = False):
    decoder_param_count = DecoderParamCount(
        num_decoder_layers=num_decoder_layers,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        head_size=head_size,
        ffn_dim=ffn_dim,
        vocab_size=vocab_size,
        emb_dim=emb_dim,
        sequence_len=sequence_len,
        use_learned_pos_emb=use_learned_pos_emb
    )
    decoder_param_count.count(include_embedding_param_count=True)


if __name__ == "__main__":
    fire.Fire(main)
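# Programmatic use (equivalent to the OPT 125M command in the docstring):
#   DecoderParamCount(num_decoder_layers=12, hidden_dim=768, num_heads=12, vocab_size=50272,
#                     sequence_len=2048, use_learned_pos_emb=True).count(include_embedding_param_count=True)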