import torch
import torchao
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
# benchmark the performance
import torch.utils.benchmark as benchmark
def benchmark_fn(f, *args, **kwargs):
    """Return the mean runtime of f(*args, **kwargs) in seconds, formatted to 3 decimals."""
    # Manual warmup
    for _ in range(5):
        f(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{t0.blocked_autorange().mean:.3f}"
# model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# quantized_model = torch.compile(quantized_model, mode="max-autotune")
# torch.backends.mha.set_fastpath_enabled(False)
# quantized_model = torchao.autoquant(quantized_model, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
# from torchao.quantization import int4_weight_only
# torchao.quantize_(quantized_model, int4_weight_only())
MAX_NEW_TOKENS = 1000
# Run generation once so autoquant can observe the shapes and benchmark candidate
# kernels before the quantization choices are finalized below.
quantized_model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")
if hasattr(quantized_model, "finalize_autoquant"):
    print("finalizing autoquant")
    quantized_model.finalize_autoquant()
quantized_model.forward = torch.compile(quantized_model.forward, mode="reduce-overhead", fullgraph=True)
# Persist the autoquant decisions so a later run can reuse them instead of re-benchmarking.
from torchao.quantization.autoquant import AUTOQUANT_CACHE
import pickle

with open("quantization-cache.pkl", "wb") as f:
    pickle.dump(AUTOQUANT_CACHE, f)
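# Sketch of the corresponding load path (following the torchao autoquant README): a later
# run, assuming a matching torchao version, could restore the cached decisions before
# autoquanting again, e.g.:
# with open("quantization-cache.pkl", "rb") as f:
#     AUTOQUANT_CACHE.update(pickle.load(f))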
print("autoquant model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
bf16_model = torch.compile(bf16_model, mode="max-autotune")
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
# breakpoint()
save_to = "jerryzh168/llama3-8b-autoquant"
quantized_model.push_to_hub(save_to, safe_serialization=False)
tokenizer.push_to_hub(save_to)
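# Hedged sketch: per the transformers TorchAo integration docs, the pushed checkpoint could
# later be reloaded straight from the Hub (repo name as pushed above; exact kwargs may vary
# with the installed transformers/torchao versions, since it is not saved as safetensors):
# reloaded_model = AutoModelForCausalLM.from_pretrained(save_to, device_map="auto", torch_dtype="auto")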