Last active July 27, 2025 19:21
count model flops
import torch
import transformers
from torch.utils.flop_counter import FlopCounterMode


def calculate_model_usage_per_token(model_path: str) -> int:
    """
    Measure forward-pass FLOPs for one token of a transformer model using torch's FlopCounterMode.

    The count comes from a single-token input, so it does not include the
    attention cost that grows with context length.

    Args:
        model_path: Path or name of the model to load and measure.
    Returns:
        FLOPs per token as an integer.
    """
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True
    )
    model.eval()  # Evaluation mode for consistent FLOPs counting
    # Create a single-token input on the model's device
    input_ids = torch.tensor([[1]], device=model.device)
    flop_counter = FlopCounterMode(display=False, depth=None)
    # no_grad: only the forward pass is measured, so skip autograd bookkeeping
    with torch.no_grad(), flop_counter:
        model(input_ids)
    return flop_counter.get_total_flops()
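
A measured count is easier to trust when it can be compared against a rough analytical figure. The sketch below is not part of the original gist: it assumes the common rule of thumb that a forward pass costs about 2 FLOPs per parameter per token (one multiply and one add per weight), and the helper names `estimate_forward_flops_per_token` and `relative_gap` are hypothetical. The numbers in the usage lines are made up for illustration, not measurements.

```python
def estimate_forward_flops_per_token(num_params: int) -> int:
    """Rule-of-thumb estimate: about 2 FLOPs per parameter per token."""
    return 2 * num_params


def relative_gap(measured: int, estimated: int) -> float:
    """Fractional difference between a measured count and the estimate."""
    return abs(measured - estimated) / estimated


# Hypothetical numbers for illustration only, not measurements of a real model.
num_params = 1_000_000_000
measured_flops = 2_100_000_000

estimate = estimate_forward_flops_per_token(num_params)
print(f"estimate: {estimate:,} FLOPs/token")
print(f"relative gap: {relative_gap(measured_flops, estimate):.1%}")
```

The rule of thumb is only approximate. Embedding lookups are not matrix multiplies, so the estimate is usually closer when `num_params` excludes the input embedding table, and a sizeable gap is not necessarily a sign that the measurement is wrong.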