I don't pass position_ids, so the prompts keep the same shape (see the note after the script below).
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional
import time
import os

os.environ["TOKENIZERS_PARALLELISM"] = "1"
device = "cuda:1"
torch.set_float32_matmul_precision('high')

prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
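
# Right-pad both prompts to a shared length; since no position_ids are passed,
# the batch keeps a single fixed shape (see the gist description).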
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to(device)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

NUM_TOKENS_TO_GENERATE = 100
torch_device = "cuda:1"
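
# One greedy decoding step: feed the last generated token together with its
# cache_position and return the argmax of the final logits.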
def decode_one_tokens(model, cur_token, input_pos, cache_position):
    logits = model(
        cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

batch_size, seq_length = inputs["input_ids"].shape
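
# Attach a static KV cache (batch of 2, up to 4096 positions) and pre-allocate a
# buffer holding the prompts plus everything generated below. `_setup_cache` was a
# private model helper at the time this gist was written.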
with torch.no_grad():
    model._setup_cache(StaticCache, 2, max_cache_len=4096)
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
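
    # Prefill: a single forward pass over the full prompts populates the cache
    # and produces the first generated token.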
    logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
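
    # Compile the per-token decode step; "reduce-overhead" relies on CUDA graphs,
    # so the static cache and fixed-shape inputs are what make the graph replayable.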
    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
    cache_position = torch.tensor([seq_length], device=torch_device)
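
    # First decode loop: the first call triggers compilation, then the remaining
    # NUM_TOKENS_TO_GENERATE - 1 tokens are generated with the compiled step.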
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
            cache_position += 1
            generated_ids[:, cache_position] = next_token.int()
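
    # Second decode loop: keeps generating with the already-compiled step
    # (cache_position carries over), which is why generated_ids was sized for
    # 2 * NUM_TOKENS_TO_GENERATE extra tokens.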
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
            cache_position += 1
            generated_ids[:, cache_position] = next_token.int()

print(generated_ids)
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)
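
Note on the description: the script never passes position_ids; right padding means every row can share positions 0..seq_length-1, driven purely by cache_position. For reference only, a minimal sketch (not part of the gist) of one common recipe for building explicit position_ids from the attention mask, which would be needed with left-padded prompts:

# Illustration only: explicit position_ids from an attention mask (e.g. for left padding).
attention_mask = inputs["attention_mask"]
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# These would then be passed as model(..., position_ids=position_ids, ...).

And since `time` is imported but never used, here is a rough timing sketch for the compiled step; it assumes the script above has already run, so the function is compiled and the static cache is warm:

with torch.no_grad():
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
            cache_position += 1
    torch.cuda.synchronize(device)
    print(f"{NUM_TOKENS_TO_GENERATE / (time.perf_counter() - start):.1f} tokens/s per sequence")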