Simple repro for an error in PyTorch's SDPA that occurs only in very specific settings.
# Torch version: 2.1.0.dev20230403+cu117
# Cuda: 11.7
# Issue summary:
# PyTorch's SDPA function is a means to use flash attention. This function doesn't work on sm_86 in some scenarios:
# - with bs=1 there's no issue for most sequence lengths (though it was found to error for seq len 3)
# - with bs>1, the module throws an error during loss.backward()
# - both observations are for head_dim > 64. This repro uses codegen-2B, which has head_dim=80.
#
# Error log: https://pastebin.com/t2Xdyb0d
#
# Relevant links:
# - https://github.com/HazyResearch/flash-attention/issues/138
# - https://github.com/pytorch/pytorch/pull/91994
# - https://github.com/pytorch/pytorch/pull/94921
# - https://discuss.pytorch.org/t/expected-is-sm80-to-be-true-but-got-false/172572
#
# Note that the original flash-attention repo does say that for head_dim > 64 the backward pass needs an A100 or H100.
# But it is not clear whether that limitation applies to PyTorch's SDPA wrapper around it.
# Moreover, if it did, the error should probably happen consistently, not only for bs>1.
# (Small sanity checks for the GPU's compute capability, the enabled SDPA backends, and the
#  model's head_dim follow the imports and the model load below.)
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
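
# Sanity check (not part of the original gist): print the GPU's compute capability
# and which SDPA backends are currently allowed, to confirm we're on sm_86 with the
# flash backend enabled. These are standard torch 2.x APIs.
print("compute capability:", torch.cuda.get_device_capability())  # (8, 6) == sm_86
print("flash sdp enabled:", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient sdp enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math sdp enabled:", torch.backends.cuda.math_sdp_enabled())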
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-multi")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-2B-multi", low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to("cuda")
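
# Sanity check (not part of the original gist): confirm head_dim > 64 for this model.
# CodeGen's config exposes n_embd and n_head; for codegen-2B that is 2560 / 32 = 80.
print("head_dim:", model.config.n_embd // model.config.n_head)  # expected: 80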

def test(bs):
    # Batch of bs dummy sequences of length 1024; forward + backward in train mode.
    input_ids = torch.tensor([0]).repeat([bs, 1024]).to("cuda")
    model.train()
    output = model(input_ids=input_ids, labels=input_ids)
    loss = output.loss
    loss.backward()
test(1)  # fine
test(2)  # fine

# Patch the attention:
# CodeGen's attention is implemented in model.transformer.h[i].attn._attn:
# https://github.com/huggingface/transformers/blob/41a2f3529c6b56866c317031375ffd3e7b8bea01/src/transformers/models/codegen/modeling_codegen.py#L125
def flash(q, k, v, *args, **kwargs):
    print(q.shape, k.shape, v.shape)
    # Cast q/k to v's dtype so the fused kernel sees a uniform dtype; return None for the
    # attention weights to match _attn's (attn_output, attn_weights) return signature.
    return F.scaled_dot_product_attention(q.to(v.dtype), k.to(v.dtype), v, is_causal=True), None

model.transformer.h[-1].attn._attn = flash
test(1)  # fine
test(2)  # not fine
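
# Possible follow-up (not in the original repro): a sketch of a workaround that disables
# the flash backend so SDPA falls back to the mem-efficient/math kernels.
# torch.backends.cuda.sdp_kernel is the backend-selection context manager in torch 2.0/2.1;
# whether the fallback actually avoids the sm_86 backward error is an assumption to verify,
# not something established above.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    test(2)  # should exercise the non-flash kernels for the patched layer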