"""
FlashAttention-3 (FA3) attention processor for the Qwen-Image double-stream
transformer, wired up through Hugging Face `kernels`. The benchmark script
below imports this module as `fa3_qwen`.
"""
import torch
from typing import Optional, Tuple

from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
try:
    from kernels import get_kernel

    # Fetch the vLLM FlashAttention-3 kernel from the Hugging Face Hub.
    _k = get_kernel("kernels-community/vllm-flash-attn3")
    _flash_attn_func = _k.flash_attn_func
except Exception as e:
    _flash_attn_func = None
    _kernels_err = e
def _ensure_fa3_available():
    if _flash_attn_func is None:
        raise ImportError(
            "FlashAttention-3 via Hugging Face `kernels` is required. "
            "Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
            f"{_kernels_err}"
        )
# Register the kernel as a PyTorch custom op so it stays a single opaque node
# under torch.compile instead of being traced through.
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
) -> torch.Tensor:
    # The kernel returns (output, softmax_lse); only the output is exposed.
    outputs, lse = _flash_attn_func(q, k, v, causal=causal)
    return outputs


@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
    # The underlying kernel produces two outputs:
    #   1. output:      (batch, seq_len, num_heads, head_dim)
    #   2. softmax_lse: (batch, num_heads, seq_len), dtype=torch.float32
    # The custom op only returns the first, so the fake mirrors that.
    return torch.empty_like(q).contiguous()
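
# Quick sanity check (a hypothetical sketch, not part of the original gist; assumes
# an FA3-capable GPU and that the kernel import above succeeded). Shapes follow the
# (B, S, H, D_h) layout used by the processor below.
#
#     q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
#     out = flash_attn_func(q, q, q, causal=False)  # eager: real kernel
#     fn = torch.compile(lambda q, k, v: flash_attn_func(q, k, v, causal=False))
#     out_c = fn(q, q, q)                           # compiled: the fake supplies shapes
#     assert out.shape == q.shape == out_c.shape
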
class QwenDoubleStreamAttnProcessorFA3:
    """
    FA3-based attention processor for the Qwen double-stream architecture.

    Computes joint attention over the concatenated [text, image] streams using
    vLLM FlashAttention-3, accessed via Hugging Face `kernels`.

    Notes / limitations:
    - Arbitrary attention masks are not supported on this FA3 path: `causal=False`
      and no mask is passed to the kernel.
    - Optional windowed attention / sink tokens / softcap could be plumbed through
      if you use those features.
    - Expects `apply_rotary_emb_qwen` to be in scope (same as the non-FA3 processor).
    """

    _attention_backend = "fa3"  # kept for parity with other processors; not used internally

    def __init__(self):
        _ensure_fa3_available()
    @torch.no_grad()
    def __call__(
        self,
        attn,  # Attention module with to_q/to_k/to_v, add_*_proj, norms, to_out, to_add_out, and .heads
        hidden_states: torch.FloatTensor,  # (B, S_img, D_model) image stream
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # (B, S_txt, D_model) text stream
        encoder_hidden_states_mask: Optional[torch.FloatTensor] = None,  # unused in the FA3 path
        attention_mask: Optional[torch.FloatTensor] = None,  # unused in the FA3 path
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (img_freqs, txt_freqs)
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream).")
        if attention_mask is not None:
            # The FA3 kernel path does not consume arbitrary masks; fail fast to
            # avoid silent correctness issues.
            raise NotImplementedError("attention_mask is not supported in this FA3 implementation.")
        _ensure_fa3_available()

        B, S_img, _ = hidden_states.shape
        S_txt = encoder_hidden_states.shape[1]
        # ---- QKV projections (image/sample stream) ----
        img_q = attn.to_q(hidden_states)  # (B, S_img, D)
        img_k = attn.to_k(hidden_states)
        img_v = attn.to_v(hidden_states)

        # ---- QKV projections (text/context stream) ----
        txt_q = attn.add_q_proj(encoder_hidden_states)  # (B, S_txt, D)
        txt_k = attn.add_k_proj(encoder_hidden_states)
        txt_v = attn.add_v_proj(encoder_hidden_states)

        # ---- Reshape to (B, S, H, D_h) ----
        H = attn.heads
        img_q = img_q.unflatten(-1, (H, -1))
        img_k = img_k.unflatten(-1, (H, -1))
        img_v = img_v.unflatten(-1, (H, -1))
        txt_q = txt_q.unflatten(-1, (H, -1))
        txt_k = txt_k.unflatten(-1, (H, -1))
        txt_v = txt_v.unflatten(-1, (H, -1))

        # ---- Q/K normalization (per the module contract) ----
        if getattr(attn, "norm_q", None) is not None:
            img_q = attn.norm_q(img_q)
        if getattr(attn, "norm_k", None) is not None:
            img_k = attn.norm_k(img_k)
        if getattr(attn, "norm_added_q", None) is not None:
            txt_q = attn.norm_added_q(txt_q)
        if getattr(attn, "norm_added_k", None) is not None:
            txt_k = attn.norm_added_k(txt_k)
        # ---- RoPE (Qwen variant) ----
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            # apply_rotary_emb_qwen expects tensors shaped (B, S, H, D_h)
            img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
            img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)

        # ---- Joint attention over [text, image] along the sequence axis ----
        # Shapes: (B, S_total, H, D_h)
        q = torch.cat([txt_q, img_q], dim=1)
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)

        # The FA3 custom op expects (B, S, H, D_h) and returns just the attention output.
        out = flash_attn_func(q, k, v, causal=False)  # (B, S_total, H, D_h)

        # ---- Back to (B, S, D_model) ----
        out = out.flatten(2, 3).to(q.dtype)

        # Split back into text / image segments
        txt_attn_out = out[:, :S_txt, :]
        img_attn_out = out[:, S_txt:, :]

        # ---- Output projections ----
        img_attn_out = attn.to_out[0](img_attn_out)
        if len(attn.to_out) > 1:
            img_attn_out = attn.to_out[1](img_attn_out)  # dropout, if present
        txt_attn_out = attn.to_add_out(txt_attn_out)

        return img_attn_out, txt_attn_out
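
# Example usage (a minimal sketch, mirroring the benchmark script below; assumes a
# CUDA device with FA3 support):
#
#     from diffusers import DiffusionPipeline
#     import torch
#
#     pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
#     pipe.to("cuda")
#     pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
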

# ---------------------------------------------------------------------------
# Benchmark script (second gist file). Imports the processor above as `fa3_qwen`.
# ---------------------------------------------------------------------------
import argparse
import json
import os
from functools import partial
from typing import Optional

import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline

# These imports back the AOT-Inductor ("aoti") path; the `aot()` helper that uses
# them is not included in this gist.
from torch._inductor.package import load_package
from spaces import aoti_capture
from spaces.zero.torch.aoti import drain_module_parameters

os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/fsx/sayak/.cache"

CKPT_ID = "Qwen/Qwen-Image"
def initialize_pipeline(
    torch_dtype=torch.bfloat16, compile: Optional[str] = None, use_fa3: bool = False
):
    pipe = DiffusionPipeline.from_pretrained(CKPT_ID, torch_dtype=torch_dtype)
    pipe.to("cuda")

    if use_fa3:
        from fa3_qwen import QwenDoubleStreamAttnProcessorFA3

        pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())

    if compile == "full":
        pipe.transformer.compile()
    elif compile == "regional":
        pipe.transformer.compile_repeated_blocks(fullgraph=True)
    elif compile == "aoti":
        # NOTE: relies on an `aot()` helper that is not defined in this gist, and
        # "aoti" is not among the CLI choices below, so this branch is unreachable as-is.
        pipe = aot(pipe, **get_pipe_kwargs(2))

    pipe.set_progress_bar_config(disable=True)
    return pipe
def benchmark_fn(func_to_benchmark):
    # Mean wall-clock latency in seconds via torch.utils.benchmark's adaptive
    # autorange, returned as a formatted string (main() converts it back to float).
    t0 = benchmark.Timer(
        stmt="func_to_benchmark()",
        globals={"func_to_benchmark": func_to_benchmark},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def get_pipe_kwargs(num_inference_steps=50):
    pipe_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "true_cfg_scale": 4.0,
        "num_inference_steps": num_inference_steps,
        "generator": torch.manual_seed(0),
    }
    return pipe_kwargs


def run_inference(pipe, pipe_kwargs):
    _ = pipe(**pipe_kwargs)
def main(args):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats("cuda")
    torch.cuda.reset_accumulated_memory_stats("cuda")

    pipe = initialize_pipeline(compile=args.compile, use_fa3=args.use_fa3)
    pipe_kwargs = get_pipe_kwargs()

    # Warmup (also triggers compilation when --compile is set).
    for _ in range(2):
        run_inference(pipe, pipe_kwargs)

    inference_func = partial(run_inference, pipe, pipe_kwargs)
    latency = float(benchmark_fn(inference_func))
    inference_memory = round(torch.cuda.max_memory_allocated() / (1024**3), 3)  # GiB

    image = pipe(**get_pipe_kwargs()).images[0]

    artifact_dict = {"time": latency, "memory": inference_memory}
    artifact_dict.update(vars(args))
    file_prefix = f"comp@{args.compile}-fa3@{args.use_fa3}"

    image.save(f"{file_prefix}.png")
    with open(f"{file_prefix}.json", "w") as f:
        json.dump(artifact_dict, f)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--compile", type=str, default=None, choices=["full", "regional"])
    parser.add_argument("--use_fa3", action="store_true")
    args = parser.parse_args()
    main(args)
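
# Example invocations (the filename `benchmark_qwen.py` is hypothetical; use
# whatever name you saved this script under):
#
#     python benchmark_qwen.py                          # eager baseline
#     python benchmark_qwen.py --use_fa3                # FA3 attention processor
#     python benchmark_qwen.py --compile full           # torch.compile the whole transformer
#     python benchmark_qwen.py --compile regional --use_fa3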