@tomaarsen · Created December 7, 2023
Attention Sinks in `transformers` showcase
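
This snippet demonstrates the `SinkCache` class in `transformers`, an implementation of Attention Sinks (Xiao et al., 2023, "Efficient Streaming Language Models with Attention Sinks"). The cache always retains the first few "sink" tokens and slides a fixed-size window over the rest, so the model can generate far past its window without cache memory growing or fluency collapsing. It requires a `transformers` build that includes `SinkCache`.
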
from transformers import AutoTokenizer, SinkCache, LlamaForCausalLM, TextStreamer
import torch

# Load the Llama 2 7B tokenizer and model in half precision across available devices
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)

# SinkCache keeps the first `num_sink_tokens` tokens ("attention sinks") plus a
# sliding window of recent tokens, so cache memory stays constant however long
# the generation runs.
cache = SinkCache(window_length=256, num_sink_tokens=4)

# Stream tokens to stdout as they are generated
streamer = TextStreamer(tokenizer)

# Force exactly 10,000 new tokens, far beyond the 256-token window, to show that
# generation stays fluent even after the cache starts evicting old tokens.
gen_out = model.generate(**inputs, do_sample=False, min_new_tokens=10000, max_new_tokens=10000, past_key_values=cache, streamer=streamer)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
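
With an unbounded KV cache, a 10,000-token generation would grow memory linearly; with naive window truncation, quality tends to degenerate once the earliest tokens are evicted. Retaining the 4 sink tokens is what keeps generation fluent here, per the Attention Sinks paper. Note that `window_length=256` also caps how much context the model actually attends to, so long-range recall is limited to that window.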