# Gist by @vwxyzjn, created February 19, 2025 17:46
import argparse
import numpy as np
p = 100 # padding token id
o = 1 # observation (prompt / input ids)
a = 2 # action (response ids)
queries = [
    [p, p, o, o, o],
    [p, o, o, o, o],
    [p, p, p, o, o],
    [o, o, o, o, o],
]
responses = [
    [a, p, p, p, p],
    [a, a, p, p, p],
    [a, p, p, p, p],
    [a, a, a, a, a],
]
query_responses = [q + r for q, r in zip(queries, responses)]
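# For reference, the concatenated query+response rows are:
#   [p, p, o, o, o, a, p, p, p, p]
#   [p, o, o, o, o, a, a, p, p, p]
#   [p, p, p, o, o, a, p, p, p, p]
#   [o, o, o, o, o, a, a, a, a, a]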
pack_length = 13
packed_query_responses = np.array([
    [o, o, o, a, o, o, o, o, a, a, o, o, a],
    [o, o, o, o, o, a, a, a, a, a, p, p, p]
])
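# A possible way to derive the packed rows above programmatically (a minimal
# sketch, not part of the original test): strip pad tokens from each
# query+response pair, greedily fill rows of at most `pack_length` tokens, and
# right-pad each row. The helper name `pack_sequences` and the greedy strategy
# are assumptions; the hard-coded array above remains the ground truth used below.
def pack_sequences(seqs, pack_length, pad_id):
    rows, row = [], []
    for seq in seqs:
        tokens = [t for t in seq if t != pad_id]  # drop left/right padding
        if len(row) + len(tokens) > pack_length:  # current row is full
            rows.append(row + [pad_id] * (pack_length - len(row)))
            row = []
        row += tokens
    rows.append(row + [pad_id] * (pack_length - len(row)))
    return np.array(rows)
# Under this greedy assumption, pack_sequences(query_responses, pack_length, p)
# should reproduce packed_query_responses.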
packed_attention_masks = np.array(
    [[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]],
     [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])
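# A minimal sketch (not part of the original test) of how the block-diagonal
# causal masks above could be generated from the per-row sub-sequence lengths
# instead of being written out by hand. The lengths [[4, 6, 3], [10]] are read
# off the packed rows; positions past the last sub-sequence stay all-zero
# (padding). The helper name `block_causal_mask` is an assumption.
def block_causal_mask(seq_lengths, pack_length):
    mask = np.zeros((pack_length, pack_length), dtype=np.int64)
    start = 0
    for length in seq_lengths:
        # causal (lower-triangular) attention within each packed sub-sequence
        mask[start:start + length, start:start + length] = np.tril(
            np.ones((length, length), dtype=np.int64)
        )
        start += length
    return mask
# np.stack([block_causal_mask([4, 6, 3], 13), block_causal_mask([10], 13)])
# matches packed_attention_masks above.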
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM2-135M")
    parser.add_argument("--attn_implementation", type=str, default="sdpa")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--torch_dtype", type=str, default="float32")
    args = parser.parse_args()
    pack_length = 13
    from transformers import AutoModelForCausalLM
    import torch
    if args.torch_dtype == "bfloat16":
        torch_dtype = torch.bfloat16
    elif args.torch_dtype == "float16":
        torch_dtype = torch.float16
    elif args.torch_dtype == "float32":
        torch_dtype = torch.float32
    else:
        raise ValueError(f"unsupported --torch_dtype: {args.torch_dtype}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch_dtype,
        attn_implementation=args.attn_implementation,
    )
    device = torch.device(args.device)
    model = model.to(device)
    # forward pass on the packed batch, passing the block-diagonal causal masks
    # as a 4D attention mask of shape (batch, 1, query_len, key_len)
    s = model.forward(
        input_ids=torch.LongTensor(packed_query_responses).to(device),
        attention_mask=torch.tensor(packed_attention_masks).unsqueeze(1).bool().to(device),
    )
    # forward pass on the unpacked batch with the usual 2D padding mask
    s2 = model.forward(
        input_ids=torch.LongTensor(query_responses).to(device),
        attention_mask=torch.LongTensor(query_responses).to(device) != p,
    )
    # test: packed logits should be the same as the raw (unpacked) logits;
    # [0, 12] is the last token of the third packed sub-sequence and [2, 5] is the
    # same token in the unpacked batch
    print("diff: ", (s.logits[0, 12] - s2.logits[2, 5]).abs().sum().item())
    # [1, 9] is the last token of the fourth sequence and [3, 9] is the same token
    # in the unpacked batch
    print("diff: ", (s.logits[1, 9] - s2.logits[3, 9]).abs().sum().item())
    # torch.testing.assert_close(s.logits[0, 12], s2.logits[2, 5], atol=1e-2, rtol=1e-2)
    # print(s.logits[0, 12], "\n", s2.logits[2, 5])
    # test: the last sequence's logits should be the same
    # torch.testing.assert_close(s.logits[1, 9], s2.logits[3, 9], atol=1e-2, rtol=1e-2)
    # print(s.logits[1, 9], "\n", s2.logits[3, 9])
"""
python x7.py --device cpu --torch_dtype float32 --attn_implementation sdpa
# baseline: this works as expected
# diff: 2.8737666606903076
# diff: 0.4484243392944336
python x7.py --device cpu --torch_dtype bfloat16 --attn_implementation sdpa
# bfloat16: the diff is too large
# diff: 125952.0
# diff: 0.96875
python x7.py --device cuda --torch_dtype bfloat16 --attn_implementation sdpa
# ValueError: AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor.
python x7.py --device cuda --torch_dtype bfloat16 --attn_implementation sdpa
# After commenting out `causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)`
# diff: 33024.0
# diff: 120320.0
python x7.py --device cuda --torch_dtype float32 --attn_implementation sdpa
# works as expected
# After commenting out `causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)`
# diff: 1.307295799255371
# diff: 4.921709060668945
python x7.py --device cuda --torch_dtype bfloat16 --attn_implementation flash_attention_2
# RuntimeError: cu_seqlens_q must have shape (batch_size + 1)
python x7.py --device cuda --torch_dtype bfloat16 --attn_implementation flex_attention
# ValueError: LlamaForCausalLM does not support an attention implementation through torch's flex_attention.
python x7.py --device cuda --torch_dtype bfloat16 --attn_implementation flex_attention --model "EleutherAI/pythia-70m"
# a leaf Variable that requires grad is being used in an in-place operation.
"""