@Quentin-Anthony
Created November 3, 2023 21:34
Transformer FLOPs with Dense/MoE
import argparse
import math
# Helper function to pretty-print FLOP counts
def convert_flops(params):
    if params == 0:
        return "0"
    size_name = ("FLOPs", "KFLOPs", "MFLOPs", "GFLOPs", "TFLOPs", "PFLOPs", "EFLOPs", "ZFLOPs", "YFLOPs")
    i = int(math.floor(math.log(params, 1000)))
    p = math.pow(1000, i)
    s = round(params / p, 2)
    return "%s %s" % (s, size_name[i])
def config_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab-size", "-v",
                        type=int,
                        default=51200,
                        help='Size of the vocab')
    parser.add_argument("--hidden-size", "-hs",
                        type=int,
                        default=6144,
                        help='Dimension of the model\'s hidden size')
    parser.add_argument("--num-attention-heads", "-a",
                        type=int,
                        default=64,
                        help='Number of attention heads used in model')
    parser.add_argument("--sequence-length", "-s",
                        type=int,
                        default=2048,
                        help='Sequence length used for training')
    parser.add_argument("--num-layers", "-l",
                        type=int,
                        default=44,
                        help='Number of transformer layers used in model')
    parser.add_argument("--moe",
                        action="store_true",
                        help='Whether our model is MoE')
    parser.add_argument("--num-experts", "-e",
                        type=int,
                        default=128,
                        help='Number of experts for MoE')
    parser.add_argument("--expert-interval", "-ei",
                        type=int,
                        default=2,
                        help='Expert interval for MoE')
    parser.add_argument("--topk", "-t",
                        type=int,
                        default=1,
                        help='Top k routing for MoE')
    parser.add_argument("--batch-size", "-b",
                        type=int,
                        default=1,
                        help='Global batch size in units of samples')
    parser.add_argument("--tokens",
                        type=int,
                        default=int(300e9),
                        help='Number of tokens you are training over')
    parser.add_argument("--no-checkpoint-activations", "-ca",
                        action='store_false',
                        help='Pass this flag if Megatron-style activation checkpointing is NOT being used',
                        dest='checkpoint_activations')
    return parser
# calculates the flops of a model given its hparams
def calc_params(args):
    assert args.topk <= args.num_experts, "You cannot route to more experts than you have!"
    assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers"

    # An A_(m x k) X B_(k x n) matrix multiplication requires 2 x m x k x n FLOPs (factor of 2 needed to account for multiplies and adds)
    # determine the flops factor.
    # If no activation checkpointing/recomputation, 1 for fwd and 2 for bwd (because we need to calculate the grads with respect to both the input and weight tensors).
    # If activation checkpointing/recomputation, add 1 more for the next full forward pass
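    # Illustrative worked example of the factor above: a single m x k @ k x n matmul
    # costs 2*m*k*n FLOPs in the forward pass, 2 * (2*m*k*n) in the backward pass
    # (one matmul each for the input gradient and the weight gradient), and another
    # 2*m*k*n if the forward pass is recomputed for activation checkpointing.
    # Dividing out the common 2*m*k*n gives iter_factor = 3 without recomputation
    # and 4 with it, which scales every per-matmul term below.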
    iter_factor = 3
    if args.checkpoint_activations:
        iter_factor += 1

    qkv_flops = iter_factor * 6 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
    attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
    attention_over_values_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
    linear_projection_flops = iter_factor * 2 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
    ffn_flops = iter_factor * 16 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
    # no activation checkpointing for embeddings
    embedding_flops = 6 * args.tokens * args.hidden_size * args.vocab_size
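    # Sketch of where the per-token, per-layer constants above come from, assuming the
    # standard transformer block with ffn_dim = 4 * hidden_size (h = hidden, s = seq len):
    #   qkv (6):          three h x h projections      -> 3 * 2 * h^2
    #   attn matrix (2):  each token scores s keys     -> 2 * s * h
    #   attn over V (2):  same shape as the score step -> 2 * s * h
    #   linear proj (2):  one h x h output projection  -> 2 * h^2
    #   ffn (16):         h -> 4h and 4h -> h          -> 2 * (2 * 4 * h^2)
    #   embedding (6):    h x vocab logit matmul, 2 * h * vocab, times 3 for fwd + bwd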
    if args.moe and args.topk > 1:
        ffn_flops += ffn_flops * args.topk / args.expert_interval

    if args.moe:
        gating_flops = 2 * args.num_experts * args.hidden_size / args.expert_interval

    total_flops = qkv_flops + attention_matrix_flops + attention_over_values_flops + linear_projection_flops + ffn_flops + embedding_flops
    if args.moe:
        total_flops += gating_flops

    print(f'Calculating number of FLOPs with training configuration: {vars(args)}\n')
    print(f'QKV FLOPs: {convert_flops(qkv_flops)}')
    print(f'Attention Matrix FLOPs: {convert_flops(attention_matrix_flops)}')
    print(f'Attention Over Values FLOPs: {convert_flops(attention_over_values_flops)}')
    print(f'Linear Projection FLOPs: {convert_flops(linear_projection_flops)}')
    print(f'FFN FLOPs: {convert_flops(ffn_flops)}')
    print(f'Embedding FLOPs: {convert_flops(embedding_flops)}')
    if args.moe:
        print(f'Gating FLOPs: {convert_flops(gating_flops)}')
    print(f'Total FLOPs for the Model: {convert_flops(total_flops)}')

if __name__ == "__main__":
    print('\nExample with Fairseq-MoE 15B: python calc_transformer_flops.py -l 12 -hs 768 --moe -e 512')
    print('Example with GPT-3 175B: python calc_transformer_flops.py -l 96 -hs 12288')
    args = config_parser().parse_args()
    calc_params(args)
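# Illustrative notes (a sketch, not a prescribed workflow): the calculator can also be
# driven programmatically by handing the parser an argv list, e.g.
#     args = config_parser().parse_args(['-l', '96', '-hs', '12288'])
#     calc_params(args)
# As a rough cross-check, dense training FLOPs are commonly approximated as
# ~6 * parameters * tokens (or ~8 * parameters * tokens with activation recomputation,
# matching iter_factor = 4). For GPT-3 175B over 300B tokens the 6ND rule gives
# 6 * 175e9 * 300e9 ≈ 3.2e23 FLOPs, the same order of magnitude as the breakdown
# printed here, which additionally counts the sequence-length-dependent attention terms.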