@Birch-san
Last active December 19, 2023 22:07
FlashAttnProcessor: a diffusers Attention processor that runs attention through flash_attn's flash_attn_func.
import torch
from typing import Optional
from flash_attn import flash_attn_func
from diffusers.models.attention import Attention


class FlashAttnProcessor:
    r"""
    Processor for implementing memory-efficient attention using flash_attn.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # flatten NCHW feature maps to (batch, height*width, channels)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        # flash_attn_func exposes no attention-mask/bias argument, so reject masks up front
        # (the mask-expansion step the xformers processor performs has no equivalent here)
        assert attention_mask is None, 'flash_attn does not implement support for attention masks'

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # self-attention: keys and values come from the same tokens as queries
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # (batch, tokens, heads * head_dim) -> (batch, tokens, heads, head_dim),
        # the layout flash_attn_func expects
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        hidden_states = flash_attn_func(
            query, key, value, dropout_p=0.0, causal=False
        )
        hidden_states = hidden_states.to(query.dtype)
        # fold the heads back: (batch, tokens, heads, head_dim) -> (batch, tokens, heads * head_dim)
        hidden_states = hidden_states.flatten(-2)

        # attn.to_out is a ModuleList of [Linear, Dropout]
        out_proj, dropout = attn.to_out
        hidden_states = out_proj(hidden_states)
        hidden_states = dropout(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        if attn.rescale_output_factor != 1:
            hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
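
Usage: a minimal sketch, assuming a Stable Diffusion checkpoint and a CUDA GPU. flash_attn's kernels only support half precision, so the pipeline must be loaded in fp16 or bf16; the checkpoint name and prompt below are placeholders.

import torch
from diffusers import StableDiffusionPipeline

# placeholder checkpoint; any Stable Diffusion model should work
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,  # flash_attn requires fp16 or bf16 inputs
).to("cuda")

# route every Attention module in the UNet through flash_attn
pipe.unet.set_attn_processor(FlashAttnProcessor())

image = pipe("a photo of an astronaut riding a horse").images[0]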
okaris commented Dec 19, 2023

Thank you for sharing this, @Birch-san. Why doesn't FlashAttnQKVPackedProcessor work with cross-attention?
