@iamlemec
Created June 8, 2024 21:40
Using KV cache with mixed causal/non-causal attention.
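The script below builds a prefix-LM style mask over a short sequence: the first n1 tokens attend to each other bidirectionally (and not to the suffix), while the remaining tokens attend to the full prefix and causally to each other. It then checks that evaluating the sequence in two passes with the KV cache (prefix first, then suffix) reproduces the hidden states of a single pass under the full mixed mask.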
import torch
from transformers.models.roberta import RobertaModel, RobertaTokenizer
# load model and tokenizer
tokenizer = RobertaTokenizer.from_pretrained('FacebookAI/roberta-base')
model = RobertaModel.from_pretrained('FacebookAI/roberta-base', is_decoder=True).to('cuda')
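# is_decoder=True makes the model expand a 2D padding mask into a causal mask
# and enables use_cache / past_key_values for incremental evaluation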
# tokenize inputs
text = 'hello world, this is a test'
inputs = tokenizer(text, return_tensors='pt').to('cuda')
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
_, n = input_ids.shape
# construct attention masks
attention_noncausal = torch.ones((n, n)).unsqueeze(0).to('cuda')
attention_causal = torch.tril(attention_noncausal)
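# a 3D mask of shape (batch, query, key) is applied as given, bypassing the
# causal expansion that 2D masks receive when is_decoder=True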
# check that the default 2D padding mask is expanded causally (is_decoder=True):
# the default output should match the explicit causal mask, not the non-causal one
outputs_default = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
outputs_causal = model(input_ids=input_ids, attention_mask=attention_causal).last_hidden_state
outputs_noncausal = model(input_ids=input_ids, attention_mask=attention_noncausal).last_hidden_state
assert torch.allclose(outputs_default, outputs_causal, atol=1e-5)
assert not torch.allclose(outputs_default, outputs_noncausal, atol=1e-5)
# construct mixed attention masks
n1 = 4
n2 = n - n1
print(n1, n2)
attention_mixed = torch.cat([
    torch.cat([torch.ones((n1, n1)), torch.zeros((n1, n2))], 1),
    torch.cat([torch.ones((n2, n1)), torch.tril(torch.ones((n2, n2)))], 1),
], 0).unsqueeze(0).to('cuda')
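# rows 0..n1-1: bidirectional within the prefix, blind to the suffix
# rows n1..n-1: full attention to the prefix, causal within the suffix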
print(attention_mixed)
# construct split attention masks
attention_mixed_one = attention_mixed[:, :n1, :n1].clone()
attention_mixed_two = attention_mixed[:, n1:, :].clone()
print(attention_mixed_one)
print(attention_mixed_two)
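# the first mask is (1, n1, n1) for the prefix pass; the second is (1, n2, n),
# with its key dimension spanning the cached prefix plus the new suffix tokens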
# evaluate full mixed attention
outputs_mixed = model(input_ids=input_ids, attention_mask=attention_mixed).last_hidden_state
# first batch of split case
return_mixed_one = model(
    input_ids=input_ids[:, :n1], attention_mask=attention_mixed_one, use_cache=True
)
cache_mixed_one = return_mixed_one.past_key_values
outputs_mixed_one = return_mixed_one.last_hidden_state
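# the cache holds each layer's key/value tensors for the first n1 positions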
# second batch of split case
return_mixed_two = model(
    input_ids=input_ids[:, n1:], attention_mask=attention_mixed_two, past_key_values=cache_mixed_one
)
outputs_mixed_two = return_mixed_two.last_hidden_state
# combine outputs
outputs_mixed_combined = torch.cat([outputs_mixed_one, outputs_mixed_two], 1)
assert torch.allclose(outputs_mixed, outputs_mixed_combined, atol=1e-5)
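# not in the original gist: a sketch extending the same cache mechanics to
# token-by-token evaluation of the suffix. it assumes (as above) that the
# model accepts an explicit 3D (batch, query, key) mask alongside the cache.
cache_stepwise = model(
    input_ids=input_ids[:, :n1], attention_mask=attention_mixed_one, use_cache=True
).past_key_values
outputs_stepwise = []
for i in range(n1, n):
    # one query row; keys span the cached positions plus the current token
    step = model(
        input_ids=input_ids[:, i:i+1],
        attention_mask=attention_mixed[:, i:i+1, :i+1],
        past_key_values=cache_stepwise,
        use_cache=True,
    )
    cache_stepwise = step.past_key_values
    outputs_stepwise.append(step.last_hidden_state)
outputs_stepwise = torch.cat(outputs_stepwise, 1)
assert torch.allclose(outputs_mixed[:, n1:], outputs_stepwise, atol=1e-5)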