import torch
import torchao
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# benchmark the performance
import torch.utils.benchmark as benchmark


def benchmark_fn(f, *args, **kwargs):
    # Manual warmup
    for _ in range(5):
        f(*args, **kwargs)

    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"
# model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# quantized_model = torch.compile(quantized_model, mode="max-autotune")
# torch.backends.mha.set_fastpath_enabled(False)
# quantized_model = torchao.autoquant(quantized_model, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
# from torchao.quantization import int4_weight_only
# torchao.quantize_(quantized_model, int4_weight_only())

MAX_NEW_TOKENS = 1000

# The first generate call exercises the model so autoquant can record the shapes it needs;
# finalize_autoquant then benchmarks the candidates and picks a quantization per layer.
quantized_model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")
if hasattr(quantized_model, "finalize_autoquant"):
    print("finalizing autoquant")
    quantized_model.finalize_autoquant()

quantized_model.forward = torch.compile(quantized_model.forward, mode="reduce-overhead", fullgraph=True)

# Save the autoquant decisions so later runs can skip re-benchmarking
from torchao.quantization.autoquant import AUTOQUANT_CACHE
import pickle

with open("quantization-cache.pkl", "wb") as f:
    pickle.dump(AUTOQUANT_CACHE, f)
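
# A minimal sketch (not in the original gist): on a later run, restore the pickled cache
# before quantizing so autoquant reuses the saved decisions instead of re-benchmarking.
# Assumes AUTOQUANT_CACHE is a plain dict, which may change across torchao versions.
# import os
# if os.path.exists("quantization-cache.pkl"):
#     with open("quantization-cache.pkl", "rb") as f:
#         AUTOQUANT_CACHE.update(pickle.load(f))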
print("autoquant model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")) | |
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16) | |
bf16_model = torch.compile(bf16_model, mode="max-autotune") | |
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")) | |
# breakpoint()

save_to = "jerryzh168/llama3-8b-autoquant"
quantized_model.push_to_hub(save_to, safe_serialization=False)
tokenizer.push_to_hub(save_to)
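
# Sketch for loading the pushed checkpoint back (uses the repo name above; untested here,
# and exact loading requirements can vary with the transformers/torchao versions):
# reloaded = AutoModelForCausalLM.from_pretrained(save_to, device_map="auto", torch_dtype="auto")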