
Gist by @Cyrilvallez · Created December 3, 2024 10:15
Automatic compilation test
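
Benchmark of transformers' automatic compilation path in generate: one baseline run with the default dynamic cache, then repeated runs with cache_implementation="static", which compiles the decoding forward pass. Swapping the CompileConfig between runs shows when a graph is (re)compiled versus reused.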
import time
import warnings

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

# Silence library warnings so the timing output stays readable
warnings.filterwarnings("ignore")
device = 1  # CUDA device index, equivalent to "cuda:1"
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# Neutralize sampling parameters (all runs below use greedy decoding)
model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0
# Llama 3 defines no pad token, so reuse the EOS token id
model.generation_config.pad_token_id = model.generation_config.eos_token_id
sequence = "Hey what's the plan"
inputs = tokenizer.encode(sequence, return_tensors='pt').to(device)
mask = torch.ones_like(inputs)  # single unpadded sequence: attend to every position
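
# 1) Baseline: default dynamic cache, eager forward pass, no compilation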
t0 = time.time()
out = model.generate(inputs, attention_mask=mask, do_sample=False, max_new_tokens=500)
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'Without static cache and compile: {dt:.3f} s')
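
# 2) First call with the static cache: generate compiles the decoding
#    forward pass with the default CompileConfig, so this timing includes
#    the compilation overhead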
t0 = time.time()
out = model.generate(inputs, attention_mask=mask, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'Compiling default config: {dt:.3f} s')
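
# 3) Same arguments again: the compiled graph is reused, so this call
#    should be much faster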
t0 = time.time()
out = model.generate(inputs, attention_mask=mask, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'Using compiled graph: {dt:.3f} s')
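
# 4) A different CompileConfig (dynamic shapes) triggers a fresh
#    compilation, so the overhead is paid again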
t0 = time.time()
out = model.generate(inputs, attention_mask=mask, do_sample=False, max_new_tokens=500, cache_implementation="static",
compile_config=CompileConfig(dynamic=True))
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'Compiling new config: {dt:.3f} s')
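
# 5) Repeating the dynamic config reuses its compiled graph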
t0 = time.time()
out = model.generate(inputs, attention_mask=mask, do_sample=False, max_new_tokens=500, cache_implementation="static",
compile_config=CompileConfig(dynamic=True))
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'Using compiled new graph: {dt:.3f} s')
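
# 6) Switching back to the default config should hit the graph compiled
#    in step 2, with no recompilation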
t0 = time.time()
out = model.generate(inputs, attention_mask=mask, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'Back to 1st config and graph: {dt:.3f} s')
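
Variant not in the original gist: recent transformers releases also let you attach the compile options to the model's generation config once, instead of passing compile_config on every call. A minimal sketch, assuming the installed version supports the compile_config attribute on GenerationConfig:

# Assumption: transformers version with CompileConfig support (where the
# generation config exposes a compile_config attribute)
model.generation_config.compile_config = CompileConfig(dynamic=True)
out = model.generate(inputs, attention_mask=mask, do_sample=False,
                     max_new_tokens=500, cache_implementation="static")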