
@finbarrtimbers
Last active July 27, 2025 19:21
count model flops
import torch
import transformers
from torch.utils.flop_counter import FlopCounterMode


def calculate_model_usage_per_token(model_path: str) -> int:
    """
    Calculate actual FLOPs per token for a transformer model using torch's FlopCounterMode.

    Args:
        model_path: Path to the actual model, for a precise measurement.

    Returns:
        FLOPs per token as an integer.
    """
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True
    )
    model.eval()  # Evaluation mode for consistent FLOP counting

    # A single-token input, so the counted FLOPs correspond to one forward pass over one token.
    input_ids = torch.tensor([[1]], device=model.device)

    flop_counter = FlopCounterMode(display=False, depth=None)
    with flop_counter:
        model(input_ids)
    return flop_counter.get_total_flops()
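
A minimal usage sketch: the model path below is only an illustrative assumption (substitute any causal LM checkpoint you have access to), and multiplying the per-token figure by a token count gives a rough estimate of total forward-pass FLOPs for a run.

if __name__ == "__main__":
    # Hypothetical model path, used here purely as an example.
    flops_per_token = calculate_model_usage_per_token("meta-llama/Llama-3.2-1B")
    num_tokens = 1_000_000

    print(f"FLOPs per token: {flops_per_token:.3e}")
    print(f"Estimated forward-pass FLOPs for {num_tokens} tokens: {flops_per_token * num_tokens:.3e}")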