#!/usr/bin/python3
## Dependencies: pip install torch torchvision pandas timm transformers
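## Requires Python 3.11+ (uses enum.StrEnum).
##
## Usage (the script filename here is just an example):
##   python3 profile_ops.py <model>
## where <model> is one of the lowercase Model enum values defined below, e.g.
##   python3 profile_ops.py resnet18 > resnet18_ops.csv
##   python3 profile_ops.py gemma_3_1b > gemma_3_1b_ops.csv
## The per-operator profile is written to stdout as CSV.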
import pandas as pd
import sys
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
from enum import StrEnum, auto
class ModelRun:
    """Callable wrapper pairing a model with its example input."""
    def __init__(self, model, input_tensor):
        self.model = model
        self.input_tensor = input_tensor

    def __call__(self):
        # Transformer inputs come as a dict of keyword arguments;
        # vision and simple models take a single tensor.
        if isinstance(self.input_tensor, dict):
            return self.model(**self.input_tensor)
        else:
            return self.model(self.input_tensor)

def simple():
    # Define a simple two-layer MLP
    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 20)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(20, 5)

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            return x

    # Instantiate model and input tensor
    model = SimpleModel()
    input_tensor = torch.randn(1, 10)
    return ModelRun(model, input_tensor)

def vision(model_name):
    import timm
    batch_size = 1
    channels = 3
    height = 224
    width = 224
    if model_name == "mobilenet":
        model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
    else:
        model = timm.create_model(model_name, pretrained=True)
    input_tensor = torch.zeros(batch_size, channels, height, width)
    # Make sure the model is on CPU
    model = model.to("cpu").eval()
    return ModelRun(model, input_tensor)

def transformer(model_name):
    import inspect
    import transformers
    if model_name.startswith("google/gemma-3"):
        model = transformers.Gemma3ForCausalLM.from_pretrained(model_name)
    elif model_name == "google/mobilebert-uncased":
        model = transformers.MobileBertForSequenceClassification.from_pretrained(model_name)
    else:
        model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    prompt = "In a single word, what is the capital of France: "
    # Keep only the tokenizer outputs that the model's forward() actually
    # accepts (e.g. some models take token_type_ids, others do not).
    model_input = {
        k: v
        for k, v
        in tokenizer(prompt, return_tensors="pt").to("cpu").items()
        if k in inspect.signature(model.forward).parameters
    }
    # Make sure the model is on CPU
    model = model.to("cpu").eval()
    return ModelRun(model, model_input)

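# Note: the meta-llama/* and google/gemma* checkpoints are gated on the
# Hugging Face Hub; downloading them requires accepting the license and
# authenticating (e.g. via `huggingface-cli login`).
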
class Model(StrEnum):
    SIMPLE = auto()
    # Vision
    RESNET18 = auto()
    RESNET50 = auto()
    MOBILENET = auto()
    # Transformers
    BERT = auto()
    MOBILEBERT = auto()
    GEMMA_2B = auto()
    GEMMA_3_1B = auto()
    GEMMA_3_4B = auto()
    GEMMA_3_12B = auto()
    GEMMA_3_27B = auto()
    GPT2 = auto()
    LLAMA_2_7B = auto()
    LLAMA_2_7B_HF = auto()
    LLAMA_3_2_1B = auto()
    LLAMA_3_2_3B = auto()
    TINYLLAMA_1_1 = auto()

# `run` is a ModelRun, not a bare model (renamed accordingly).
match Model(sys.argv[1]):
    case Model.SIMPLE:
        run = simple()
    case Model.RESNET18:
        run = vision("resnet18")
    case Model.RESNET50:
        run = vision("resnet50")
    case Model.MOBILENET:
        run = vision("mobilenet")
    case Model.BERT:
        run = transformer("google-bert/bert-base-uncased")
    case Model.MOBILEBERT:
        run = transformer("google/mobilebert-uncased")
    case Model.GEMMA_2B:
        run = transformer("google/gemma-2b")
    case Model.GEMMA_3_1B:
        run = transformer("google/gemma-3-1b-it")
    case Model.GEMMA_3_4B:
        run = transformer("google/gemma-3-4b-it")
    case Model.GEMMA_3_12B:
        run = transformer("google/gemma-3-12b-it")
    case Model.GEMMA_3_27B:
        run = transformer("google/gemma-3-27b-it")
    case Model.GPT2:
        run = transformer("gpt2")
    case Model.LLAMA_2_7B:
        run = transformer("meta-llama/Llama-2-7b")
    case Model.LLAMA_2_7B_HF:
        run = transformer("meta-llama/Llama-2-7b-hf")
    case Model.LLAMA_3_2_1B:
        run = transformer("meta-llama/Llama-3.2-1B")
    case Model.LLAMA_3_2_3B:
        run = transformer("meta-llama/Llama-3.2-3B")
    case Model.TINYLLAMA_1_1:
        run = transformer("TinyLlama/TinyLlama_v1.1")

#from torch.export import export
#executorch_model = export(run.model, (run.input_tensor, ))

# Profile the forward pass
with profiler.profile(use_cpu=True, record_shapes=True) as prof:
    output = run()
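
# Alternatively, prof.key_averages().table(sort_by="self_cpu_time_total")
# returns a pre-formatted text summary; the CSV below is easier to
# post-process.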

# Collect the list of executed operators
profiler_data = [
    {
        "Operator": evt.key,
        "Self CPU (us)": int(evt.self_cpu_time_total),
        "Total CPU (us)": int(evt.cpu_time_total),
        "Total CPU Avg (us)": int(evt.cpu_time),
        "Calls": evt.count,
        "Shapes": evt.input_shapes,
    }
    for evt in prof.key_averages(group_by_input_shape=True)
]

# Print as CSV, sorted by self CPU time
df = pd.DataFrame(profiler_data)
df.sort_values("Self CPU (us)", ascending=False, inplace=True)
df.to_csv(sys.stdout, index=False)
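
## Example output columns (row contents and timings vary by model and machine):
##   Operator,Self CPU (us),Total CPU (us),Total CPU Avg (us),Calls,Shapes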