@ArthurZucker
Created March 1, 2024 03:03
I don't pass the positions, so the prompts have the same shape.
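For context: the compiled decode step below only ever receives one token per sequence and a single cache position, and no position_ids tensor at all, so its input shapes never change from step to step. A minimal sketch of those shapes (the values here are illustrative, not taken from the gist):

import torch

batch_size = 2      # two prompts, as in the script below
seq_length = 16     # illustrative length of the right-padded prompts

cur_token = torch.zeros(batch_size, 1, dtype=torch.long)   # shape (2, 1) on every step
cache_position = torch.tensor([seq_length])                # shape (1,) on every step
position_ids = None                                        # never passed to the compiled step

# Identical shapes on every step are what let torch.compile(..., fullgraph=True)
# reuse a single graph for the whole decode loop.
print(cur_token.shape, cache_position.shape, position_ids)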
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional
import time
import os
os.environ["TOKENIZERS_PARALLELISM"] = "1"
device = "cuda:1"
torch.set_float32_matmul_precision('high')
prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
# Right padding gives both prompts the same sequence length.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to(device)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

NUM_TOKENS_TO_GENERATE = 100
torch_device = "cuda:1"
def decode_one_tokens(model, cur_token, input_pos, cache_position):
    # One decode step: no position_ids are passed (input_pos stays None),
    # only the current token and its explicit position in the static cache.
    logits = model(
        cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
    # Allocate a static KV cache: batch size 2, up to 4096 cached positions.
    model._setup_cache(StaticCache, 2, max_cache_len=4096)
    cache_position = torch.arange(seq_length, device=torch_device)
    # Output buffer sized for the prompt plus both decode loops below.
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

    # Prefill: run the full padded prompts once and take the first generated token.
    logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
    # Compile the single-token decode step; with the static cache and fixed-shape
    # inputs, one compiled graph is reused for every generated token.
    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

    cache_position = torch.tensor([seq_length], device=torch_device)
    # First decode loop (the first iterations also pay the torch.compile cost).
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
        cache_position += 1
        generated_ids[:, cache_position] = next_token.int()

    # Second decode loop: same step, now running the already-compiled graph.
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
        cache_position += 1
        generated_ids[:, cache_position] = next_token.int()
print(generated_ids)
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)
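The script imports time but never uses it; here is a minimal sketch of how one could time individual decode steps to see the effect of compilation between the two loops (the helper below is my addition, not part of the gist):

import time
import torch

def timed(fn):
    # Hypothetical helper: run fn, wait for the GPU to finish, return (result, seconds).
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn()
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

# Usage around one compiled decode step, e.g. inside either loop above:
#   _, dt = timed(lambda: decode_one_tokens(model, next_token.clone(), None, cache_position))
#   print(f"decode step: {dt * 1000:.2f} ms")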
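For comparison, recent transformers releases (4.38 and later, if I recall the version correctly) expose the same static-cache path through generate(); a minimal sketch, assuming that API is available in your install:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
# Left padding is the usual choice for batched generate() with decoder-only models.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda:1")

# Ask generate() to use a StaticCache and compile the forward pass once.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))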