Issue title: (working implementation) Fused multi-head attention for arbitrary sequence lengths.
TL;DR: you can run multi-head attention (forward + backward) faster and with no extra memory overhead – with any sequence length and head dimension. We’d love to make it available via apex, and we need your advice on how best to do that.
Why should I care? Here's how it compares against the standard multi-head attention implementation (blue) for a single multi-head attention layer of GPT-J on an RTX 3080Ti; a benchmark sketch follows the plots below.
*[Plots: "time, with backward (ms)" and "peak VRAM allocated (MB)", fused attention vs. the standard baseline]*
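For reference, here is a minimal sketch of how such a baseline comparison could be set up. It benchmarks a plain PyTorch softmax(QKᵀ/√d)V attention (time including backward, plus peak VRAM); the fused kernel would simply be swapped in for `standard_attention`. The shapes follow GPT-J-6B (16 heads, head dim 256), but the function names, sequence lengths, and iteration count are illustrative assumptions, not the measurement script used for the plots above.

```python
# Hypothetical benchmark sketch -- not the script that produced the plots above.
import math
import torch

def standard_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim); materializes the full attention matrix
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v

def benchmark(attn_fn, seq_len, batch=1, heads=16, head_dim=256, iters=10):
    q, k, v = (torch.randn(batch, heads, seq_len, head_dim,
                           device="cuda", dtype=torch.float16,
                           requires_grad=True) for _ in range(3))
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        out = attn_fn(q, k, v)
        out.sum().backward()                  # include the backward pass
        q.grad = k.grad = v.grad = None
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    mb = torch.cuda.max_memory_allocated() / 2**20
    print(f"seq_len={seq_len}: {ms:.2f} ms/iter, peak {mb:.0f} MB")

for n in (512, 1024, 2048, 4096):
    benchmark(standard_attention, n)
```

Comparing against this baseline makes the "no extra memory" claim concrete: the standard implementation's peak VRAM grows with the materialized (seq_len × seq_len) attention matrix, while a fused kernel avoids allocating it.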