@freckletonj
freckletonj / kv_cache_generate.py
Created August 30, 2024 00:53
Generate tokens using past_key_values/kv-cache in transformers
import torch

def generate_with_cache(model, model_inputs, max_new_tokens):
    ''' Use past_key_values (the KV cache): process the prompt once, then feed only the newest token each step. '''
    generated_tokens = []
    past_key_values = None
    next_token = None
    input_ids = model_inputs['input_ids']
    attention_mask = model_inputs['attention_mask']
    for i in range(max_new_tokens):
        # Full prompt on the first step; afterwards only the last generated token is fed in.
        out = model(input_ids=input_ids if next_token is None else next_token,
                    attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy decoding
        generated_tokens.append(next_token)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=1)
    return torch.cat(generated_tokens, dim=1)
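
A minimal usage sketch (not part of the original gist; the gpt2 checkpoint and the prompt are only placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
inputs = tokenizer('The KV cache speeds up decoding because', return_tensors='pt')
new_tokens = generate_with_cache(model, inputs, max_new_tokens=20)
print(tokenizer.decode(new_tokens[0], skip_special_tokens=True))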