Simple repro for an error with PyTorch's SDPA that happens in very specific settings.
# Torch version: 2.1.0.dev20230403+cu117
# Cuda: 11.7
# Issue summary:
# PyTorch's SDPA function is a means to use flash attention. This function doesn't work on sm_86 under some scenarios:
# - with bs=1 there's no issue for most sequence lengths (although it was found erroring for seq len 3)
# - with bs>1 the module throws an error during loss.backward()
# - both of the above are observed when head_dim > 64. This repro uses codegen-2B, which has head_dim=80.
#
# See this for error log: https://pastebin.com/t2Xdyb0d
#
# Relevant links:
# - https://github.com/HazyResearch/flash-attention/issues/138
# - https://github.com/pytorch/pytorch/pull/91994
# - https://github.com/pytorch/pytorch/pull/94921
# - https://discuss.pytorch.org/t/expected-is-sm80-to-be-true-but-got-false/172572
#
# Note that the OG flash attention repo does say that for head_dim > 64, the backward pass needs an A100 or H100.
# But it was not clear whether that limitation also applies to PyTorch's SDPA wrapper around it.
# Moreover, the error should probably happen consistently, not only for bs>1, right?
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-multi")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-2B-multi", low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to("cuda")
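# Quick sanity check (added here for clarity, not in the original gist): confirm we are on sm_86
# and that the model's head_dim is 80. Assumes the config exposes n_embd / n_head, as GPT-J-style
# configs do.
print("compute capability:", torch.cuda.get_device_capability())  # expect (8, 6) on sm_86
print("head_dim:", model.config.n_embd // model.config.n_head)    # expect 80 for codegen-2B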
def test(bs):
    # Build a dummy batch of shape (bs, 1024) and run one forward + backward pass.
    input_ids = torch.tensor([0]).repeat([bs, 1024]).to("cuda")
    model.train()
    output = model(input_ids=input_ids, labels=input_ids)
    loss = output.loss
    loss.backward()
test(1) # fine
test(2) # fine
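# Optional isolation step (a sketch, not part of the original repro): run SDPA directly on raw
# tensors with head_dim=80 and the flash backend forced, to probe whether the failure lives in
# the SDPA kernel itself rather than in the model patching below. Shapes mirror codegen-2B at
# seq len 1024; assumes torch.backends.cuda.sdp_kernel is available in this nightly.
def sdpa_only(bs, heads=32, seq=1024, head_dim=80):
    q = torch.randn(bs, heads, seq, head_dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    # Restrict dispatch to the flash-attention kernel only.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()
# sdpa_only(1); sdpa_only(2)  # uncomment to test SDPA in isolation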
# Patch the attention of the last block to go through SDPA.
# CodeGen's attention is implemented in model.transformer.h[i].attn._attn:
# https://github.com/huggingface/transformers/blob/41a2f3529c6b56866c317031375ffd3e7b8bea01/src/transformers/models/codegen/modeling_codegen.py#L125
def flash(q, k, v, *args, **kwargs):
    # Drop-in replacement for CodeGenAttention._attn that routes through SDPA.
    # _attn returns (attn_output, attn_weights); SDPA doesn't expose weights, so return None for them.
    print(q.shape, k.shape, v.shape)
    return F.scaled_dot_product_attention(q.to(v.dtype), k.to(v.dtype), v, is_causal=True), None
model.transformer.h[-1].attn._attn = flash
test(1) # fine
test(2) # not fine: errors during loss.backward()
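# A possible workaround (only a sketch, not verified in the original gist): keep SDPA but exclude
# the flash kernel so the mem_efficient / math backends are used on sm_86 for head_dim > 64.
# Again assumes torch.backends.cuda.sdp_kernel is available in this build.
def sdpa_without_flash(q, k, v, *args, **kwargs):
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
        return F.scaled_dot_product_attention(q.to(v.dtype), k.to(v.dtype), v, is_causal=True), None
# model.transformer.h[-1].attn._attn = sdpa_without_flash
# test(2)  # expected to avoid the flash backward limitation, but untested here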