Created
March 25, 2025 20:52
-
-
Save jerryzh168/0e749d0dab40e2a62a7f2e48639f77b5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TorchAoConfig | |
| from PIL import Image | |
| import requests | |
| import torch.utils.benchmark as benchmark | |
| from torchao.utils import benchmark_model | |
def benchmark_fn(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return the mean runtime (seconds) as a string.

    Two untimed warmup calls are made first, then the mean of
    ``torch.utils.benchmark.Timer.blocked_autorange`` is formatted to
    three decimal places.
    """
    # Warm up so one-time costs (compilation, caching) don't skew the timing.
    warmup_iterations = 2
    for _ in range(warmup_iterations):
        f(*args, **kwargs)
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
        num_threads=torch.get_num_threads(),
    )
    mean_seconds = timer.blocked_autorange().mean
    return f"{mean_seconds:.3f}"
# Model under test. torchao supports int4_weight_only, int8_weight_only and
# int8_dynamic_activation_int8_weight; more examples and argument docs:
# https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
model_id = "google/gemma-3-4b-it"

from torchao.quantization import Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig

quant_config = Int4WeightOnlyConfig(group_size=128)
# quant_config = Int8DynamicActivationInt4WeightConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load the model quantized on the fly onto the GPU.
quantized_model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda",
    quantization_config=quantization_config,
    torch_dtype="auto",
)
processor = AutoProcessor.from_pretrained(model_id, padding_side="right")

# Save the quantized model locally:
# output_dir = "llama3-8b-int4wo-128"
# quantized_model.save_pretrained(output_dir, safe_serialization=False)

# ...or push it to the Hugging Face Hub.
save_to = "jerryzh168/gemma3-int4wo"
# save_to = "jerryzh168/gemma3-8da4w"
save_to_gguf = save_to + "-gguf"
quantized_model.push_to_hub(save_to, safe_serialization=False)
# Chat-formatted request: a system prompt plus a user turn pairing an image
# URL with a text instruction.
system_turn = {
    "role": "system",
    "content": [{"type": "text", "text": "You are a helpful assistant."}],
}
user_turn = {
    "role": "user",
    "content": [
        {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
        {"type": "text", "text": "Describe this image in detail."},
    ],
}
messages = [system_turn, user_turn]

# Tokenize the chat (and preprocess the image) into model-ready tensors on
# the model's device; floating-point tensors are cast to bfloat16.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(quantized_model.device, dtype=torch.bfloat16)

# Prompt length in tokens, used later to strip the prompt from the output.
input_len = inputs["input_ids"].shape[-1]
# Compile the quantized model to get speedup.
import torchao

torchao.quantization.utils.recommended_inductor_config_setter()
quantized_model = torch.compile(quantized_model, mode="max-autotune")

MAX_NEW_TOKENS = 1000

# Sanity-check the quantized model's output before benchmarking.
with torch.inference_mode():
    generation = quantized_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
    generation = generation[0][input_len:]  # drop the prompt tokens
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)

# Benchmark 1: the freshly quantized int4 weight-only model.
with torch.inference_mode():
    print("int4wo model:", benchmark_fn(quantized_model.generate, **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False))

# Benchmark 2: bf16 baseline.
# FIX: the original loaded this model with no dtype, so it actually ran in
# float32 while being labeled "bf16" — load it in bfloat16 explicitly so the
# comparison is what the label claims.
bf16_model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="cuda", torch_dtype=torch.bfloat16
)
bf16_model = torch.compile(bf16_model, mode="max-autotune")
with torch.inference_mode():
    print("loaded bf16 model:", benchmark_fn(bf16_model.generate, **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False))

# Benchmark 3: the quantized model reloaded from the Hub checkpoint, to
# confirm the serialized artifact performs the same as the in-memory one.
quantized_model = Gemma3ForConditionalGeneration.from_pretrained(save_to, device_map="cuda:0", torch_dtype="auto")
quantized_model = torch.compile(quantized_model, mode="max-autotune")
with torch.inference_mode():
    print("loaded int4wo model:", benchmark_fn(quantized_model.generate, **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment