@SunMarc
Last active September 16, 2024 23:01
`transformers` + `torchao` quantization + `torch.compile` on Llama3.1 8B
# REQUIRES torchao, torch nightly (or torch 2.5) and transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TorchAoConfig
from transformers import TextStreamer
import torch
from tqdm import tqdm
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
torch.set_float32_matmul_precision('high')
# Other configuration options
DEVICE = "cuda:0"
NUM_RUNS = 10
MAX_NEW_TOKENS = 500
# Load the model and prepare generate args
repo_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Set the quantization config
# You can choose among int4_weight_only (4-bit), int8_weight_only (8-bit), and int8_dynamic_activation_int8_weight (8-bit)
# group_size is only for int4_weight_only and needs to be one of [32,64,128,256]
# quantization_config = TorchAoConfig(quant_type="int4_weight_only", group_size=128)
# Loading the quantized model takes 6218 MB
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
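# Sketch (assumes the TorchAoConfig line above is uncommented): to benchmark the
# quantized model instead of the bf16 baseline, pass the config to from_pretrained, e.g.
# model = AutoModelForCausalLM.from_pretrained(
#     repo_id,
#     torch_dtype=torch.bfloat16,
#     device_map=DEVICE,
#     quantization_config=quantization_config,
# )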
model.generation_config.cache_implementation = "static"  # static KV cache keeps decode shapes fixed for compilation
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
messages = [
    {"role": "user", "content": "Write a story: "},
]
model_inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(DEVICE)
generate_kwargs = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "do_sample": True,
    "temperature": 0.2,
    "eos_token_id": -1,  # forces the generation of `max_new_tokens`
}
# Warmup
print("Warming up...")
for _ in range(2):  # the first calls trigger compilation, so keep them out of the timed runs
    gen_out = model.generate(**model_inputs, **generate_kwargs)
print("Done!")
# Measure OR Stream
def measure_generate(model, model_inputs, generate_kwargs):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats(DEVICE)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    for _ in tqdm(range(NUM_RUNS)):
        gen_out = model.generate(**model_inputs, **generate_kwargs)
    end_event.record()
    torch.cuda.synchronize()

    max_memory = torch.cuda.max_memory_allocated(DEVICE)
    print("Max memory (MB): ", max_memory * 1e-6)
    print("Throughput (tokens/sec): ", (NUM_RUNS * MAX_NEW_TOKENS) / (start_event.elapsed_time(end_event) * 1.0e-3))
def stream_generate(model, model_inputs, generate_kwargs):
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_out = model.generate(**model_inputs, streamer=streamer, **generate_kwargs)

stream_generate(model, model_inputs, generate_kwargs)
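# To benchmark throughput and peak memory instead of streaming a single story,
# call the measurement helper defined above in place of stream_generate:
# measure_generate(model, model_inputs, generate_kwargs)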