Skip to content

Instantly share code, notes, and snippets.

@jerryzh168
Created March 25, 2025 20:52
Show Gist options
  • Select an option

  • Save jerryzh168/0e749d0dab40e2a62a7f2e48639f77b5 to your computer and use it in GitHub Desktop.

Select an option

Save jerryzh168/0e749d0dab40e2a62a7f2e48639f77b5 to your computer and use it in GitHub Desktop.
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TorchAoConfig
from PIL import Image
import requests
import torch.utils.benchmark as benchmark
from torchao.utils import benchmark_model
def benchmark_fn(f, *args, **kwargs):
    """Time callable *f(*args, **kwargs)* and return its mean runtime in seconds.

    Runs two warmup calls first, then measures with
    ``torch.utils.benchmark.Timer.blocked_autorange`` for stable timings.

    Returns:
        str: mean runtime in seconds, formatted to three decimal places.
    """
    # Manual warmup so one-time costs (compilation, lazy init, caches)
    # do not skew the measured mean.
    for _ in range(2):
        f(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"
# Model under test: Gemma 3 4B instruction-tuned checkpoint.
model_id = "google/gemma-3-4b-it"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
from torchao.quantization import Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig
# int4 weight-only quantization with per-group scales over groups of 128 weights.
quant_config = Int4WeightOnlyConfig(group_size=128)
# quant_config = Int8DynamicActivationInt4WeightConfig()
# Wrap the torchao config so transformers applies it during from_pretrained.
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load directly onto CUDA; weights are quantized on load. torch_dtype="auto"
# keeps the non-quantized tensors in the checkpoint's native dtype.
quantized_model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="right")
# save quantized model
# output_dir = "llama3-8b-int4wo-128"
# quantized_model.save_pretrained(output_dir, safe_serialization=False)
# push to hub
save_to = "jerryzh168/gemma3-int4wo"
# save_to = "jerryzh168/gemma3-8da4w"
# NOTE(review): save_to_gguf is defined but not used in this script.
save_to_gguf = save_to + "-gguf"
# safe_serialization=False: torchao-quantized tensors need pickle-based saving.
quantized_model.push_to_hub(save_to, safe_serialization=False)
# Chat-formatted conversation: a system prompt followed by a user turn that
# pairs an image URL with a text instruction.
messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."},
        ],
    },
]
# Render the chat template, tokenize (text + image) and move the resulting
# tensors to the model's device in bf16 to match the compute dtype.
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True,
return_dict=True, return_tensors="pt"
).to(quantized_model.device, dtype=torch.bfloat16)
# Prompt length in tokens — used later to slice the prompt off the generation.
input_len = inputs["input_ids"].shape[-1]
# compile the quantized model to get speedup
import torchao

# Apply torchao's recommended torch.compile/inductor settings for
# quantized inference.
torchao.quantization.utils.recommended_inductor_config_setter()
quantized_model = torch.compile(quantized_model, mode="max-autotune")

MAX_NEW_TOKENS = 1000

# Greedy-decode once and print the response, with the prompt tokens sliced off.
with torch.inference_mode():
    generation = quantized_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)

# Benchmark the compiled int4 weight-only model.
with torch.inference_mode():
    print("int4wo model:", benchmark_fn(quantized_model.generate, **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False))

# Baseline: the unquantized model loaded fresh, also compiled.
bf16_model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="cuda")
bf16_model = torch.compile(bf16_model, mode="max-autotune")
with torch.inference_mode():
    print("loaded bf16 model:", benchmark_fn(bf16_model.generate, **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False))

# Round-trip check: reload the quantized checkpoint just pushed to the hub
# and benchmark it the same way.
quantized_model = Gemma3ForConditionalGeneration.from_pretrained(save_to, device_map="cuda:0", torch_dtype="auto")
quantized_model = torch.compile(quantized_model, mode="max-autotune")
with torch.inference_mode():
    print("loaded int4wo model:", benchmark_fn(quantized_model.generate, **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment