`transformers` + `torchao` quantization + `torch.compile` on Llama3.1 8B
# REQUIRES torchao, torch nightly (or torch 2.5) and transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TorchAoConfig
from transformers import TextStreamer
import torch
from tqdm import tqdm
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)
torch.set_float32_matmul_precision('high')

# Other configuration options
DEVICE = "cuda:0"
NUM_RUNS = 10
MAX_NEW_TOKENS = 500

# Load the model and prepare generate args
repo_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Set the quantization config
# You can choose between int4_weight_only (4-bit), int8_weight_only (8-bit) and int8_dynamic_activation_int8_weight (8-bit)
# group_size only applies to int4_weight_only and needs to be one of [32, 64, 128, 256]
# (an int8_weight_only variant is sketched after the script below)
quantization_config = TorchAoConfig(quant_type="int4_weight_only", group_size=128)

# Loading the quantized model takes 6218 MB
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
    quantization_config=quantization_config,
)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
messages = [
    {"role": "user", "content": "Write a story: "},
]
model_inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(DEVICE)
generate_kwargs = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "do_sample": True,
    "temperature": 0.2,
    "eos_token_id": -1,  # forces the generation of `max_new_tokens`
}
# Warmup
print("Warming up...")
for _ in range(2):
    gen_out = model.generate(**model_inputs, **generate_kwargs)
print("Done!")


# Measure OR Stream
def measure_generate(model, model_inputs, generate_kwargs):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats(DEVICE)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    for _ in tqdm(range(NUM_RUNS)):
        gen_out = model.generate(**model_inputs, **generate_kwargs)
    end_event.record()
    torch.cuda.synchronize()

    max_memory = torch.cuda.max_memory_allocated(DEVICE)
    print("Max memory (MB): ", max_memory * 1e-6)
    print("Throughput (tokens/sec): ", (NUM_RUNS * MAX_NEW_TOKENS) / (start_event.elapsed_time(end_event) * 1.0e-3))


def stream_generate(model, model_inputs, generate_kwargs):
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_out = model.generate(**model_inputs, streamer=streamer, **generate_kwargs)


stream_generate(model, model_inputs, generate_kwargs)
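
# ------------------------------------------------------------------
# Optional extra (a sketch, not part of the original gist): repeat the
# benchmark with the 8-bit weight-only torchao config from the options
# listed above, for comparison against int4. int8_weight_only takes no
# group_size argument (that option is int4-only). Assumes the GPU has
# room for the second model once the first one is freed.
del model
torch.cuda.empty_cache()

int8_config = TorchAoConfig(quant_type="int8_weight_only")
model_int8 = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
    quantization_config=int8_config,
)
model_int8.generation_config.cache_implementation = "static"
model_int8.forward = torch.compile(model_int8.forward, mode="reduce-overhead", fullgraph=True)

# Warm up the compiled graph before timing, as done above
for _ in range(2):
    model_int8.generate(**model_inputs, **generate_kwargs)
measure_generate(model_int8, model_inputs, generate_kwargs)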