#!/usr/bin/python3
## Dependencies: pip install torch torchvision pandas timm transformers
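## Usage: python3 <this script> <model>, where <model> is one of the Model
## enum values below (e.g. "simple", "resnet50", "gpt2").
## Note: enum.StrEnum requires Python >= 3.11.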

import pandas as pd
import sys
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn

from enum import StrEnum, auto


class ModelRun:
    """Bundles a model with its prepared input so both can be invoked uniformly."""

    def __init__(self, model, input_tensor):
        self.model = model
        self.input_tensor = input_tensor

    def __call__(self):
        # Transformer inputs are dicts of keyword tensors; vision and simple
        # models take a single positional tensor.
        if isinstance(self.input_tensor, dict):
            return self.model(**self.input_tensor)
        else:
            return self.model(self.input_tensor)

def simple():
    # Define a small two-layer MLP as a baseline workload
    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 20)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(20, 5)

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            return x

    # Instantiate model and input tensor
    model = SimpleModel()
    input_tensor = torch.randn(1, 10)
    return ModelRun(model, input_tensor)

def vision(model_name):
    import timm

    # Standard ImageNet-sized input: 1 x 3 x 224 x 224
    batch_size = 1
    channels = 3
    height = 224
    width = 224

    if model_name == "mobilenet":
        model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
    else:
        model = timm.create_model(model_name, pretrained=True)
    input_tensor = torch.zeros(batch_size, channels, height, width)

    # Make sure the model is on CPU and in inference mode
    model = model.to("cpu").eval()
    return ModelRun(model, input_tensor)

def transformer(model_name):
    import inspect
    import transformers

    if model_name.startswith("google/gemma-3"):
        model = transformers.Gemma3ForCausalLM.from_pretrained(model_name)
    elif model_name == "google/mobilebert-uncased":
        model = transformers.MobileBertForSequenceClassification.from_pretrained(model_name)
    else:
        model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    prompt = "In a single word, what is the capital of France: "
    # Keep only the tokenizer outputs that model.forward actually accepts,
    # since tokenizers can emit keys (e.g. token_type_ids) some models reject.
    model_input = {
        k: v
        for k, v
        in tokenizer(prompt, return_tensors="pt").to("cpu").items()
        if k in inspect.signature(model.forward).parameters
    }

    # Make sure the model is on CPU and in inference mode
    model = model.to("cpu").eval()
    return ModelRun(model, model_input)

class Model(StrEnum):
    SIMPLE = auto()
    # Vision
    RESNET18 = auto()
    RESNET50 = auto()
    MOBILENET = auto()
    # Transformers
    BERT = auto()
    MOBILEBERT = auto()
    GEMMA_2B = auto()
    GEMMA_3_1B = auto()
    GEMMA_3_4B = auto()
    GEMMA_3_12B = auto()
    GEMMA_3_27B = auto()
    GPT2 = auto()
    LLAMA_2_7B = auto()
    LLAMA_2_7B_HF = auto()
    LLAMA_3_2_1B = auto()
    LLAMA_3_2_3B = auto()
    TINYLLAMA_1_1 = auto()

# Fail with a usage message instead of an IndexError when no model is given
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <model>")

match Model(sys.argv[1]):
    case Model.SIMPLE:
        model = simple()
    case Model.RESNET18:
        model = vision("resnet18")
    case Model.RESNET50:
        model = vision("resnet50")
    case Model.MOBILENET:
        model = vision("mobilenet")
    case Model.BERT:
        model = transformer("google-bert/bert-base-uncased")
    case Model.MOBILEBERT:
        model = transformer("google/mobilebert-uncased")
    case Model.GEMMA_2B:
        model = transformer("google/gemma-2b")
    case Model.GEMMA_3_1B:
        model = transformer("google/gemma-3-1b-it")
    case Model.GEMMA_3_4B:
        model = transformer("google/gemma-3-4b-it")
    case Model.GEMMA_3_12B:
        model = transformer("google/gemma-3-12b-it")
    case Model.GEMMA_3_27B:
        model = transformer("google/gemma-3-27b-it")
    case Model.GPT2:
        model = transformer("gpt2")
    case Model.LLAMA_2_7B:
        model = transformer("meta-llama/Llama-2-7b")
    case Model.LLAMA_2_7B_HF:
        model = transformer("meta-llama/Llama-2-7b-hf")
    case Model.LLAMA_3_2_1B:
        model = transformer("meta-llama/Llama-3.2-1B")
    case Model.LLAMA_3_2_3B:
        model = transformer("meta-llama/Llama-3.2-3B")
    case Model.TINYLLAMA_1_1:
        model = transformer("TinyLlama/TinyLlama_v1.1")

#from torch.export import export
#executorch_model = export(model.model, (model.input_tensor, ))

# Profile the forward pass; no_grad avoids recording autograd bookkeeping
# that would otherwise pollute the inference-only operator counts
with torch.no_grad():
    with profiler.profile(use_cpu=True, record_shapes=True) as prof:
        output = model()

# Collect the executed operators, aggregated by operator and input shape
profiler_data = [
    {
        "Operator": evt.key,
        "Self CPU (us)": int(evt.self_cpu_time_total),
        "Total CPU (us)": int(evt.cpu_time_total),
        "Total CPU Avg (us)": int(evt.cpu_time),
        "Calls": evt.count,
        "Shapes": evt.input_shapes,
    }
    for evt in prof.key_averages(group_by_input_shape=True)
]

# Print as CSV, most expensive operators first
df = pd.DataFrame(profiler_data)
df.sort_values("Self CPU (us)", ascending=False, inplace=True)
df.to_csv(sys.stdout, index=False)
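
## Example invocation (hypothetical file name): dump the per-operator CSV for
## GPT-2 and inspect the hottest kernels:
##   python3 profile_torch_ops.py gpt2 > gpt2_ops.csv
##   head -n 5 gpt2_ops.csv
## Gated checkpoints (e.g. meta-llama/*, google/gemma-*) additionally require
## `huggingface-cli login` and an accepted license on the Hugging Face Hub.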