ring attention when you forget to do the rotations
import argparse | |
import contextlib | |
import math | |
import pathlib | |
from typing import List, Optional, Tuple | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
import torch.nn as nn | |
import torch.profiler._utils | |
import torch._dynamo.config | |
import torch._inductor.config | |
import torch._higher_order_ops.auto_functionalize as af | |
import triton | |
import triton.language as tl | |
from torch.distributed.tensor import DTensor, Shard | |
from torch.profiler import profile, record_function, ProfilerActivity | |
try:
    from flash_attn import flash_attn_func
except ImportError:
    print("Flash Attention 2 not found.")
try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    print("Flash Attention 3 not found.")
from diffusers.models.autoencoders import AutoencoderKL | |
from diffusers.image_processor import VaeImageProcessor | |
from diffusers.configuration_utils import ConfigMixin, register_to_config | |
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin | |
from diffusers.models.attention import FeedForward | |
from diffusers.models.embeddings import ( | |
get_1d_rotary_pos_embed, | |
TimestepEmbedding, | |
Timesteps, | |
PixArtAlphaTextProjection, | |
) | |
from diffusers.models.cache_utils import CacheMixin | |
from diffusers.models.modeling_utils import ModelMixin | |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast | |
torch._dynamo.config.inline_inbuilt_nn_modules = False | |
torch._inductor.config.coordinate_descent_tuning = True | |
# torch._inductor.config.coordinate_descent_check_all_directions = True | |
torch._inductor.config.coordinate_descent_search_radius = 1 | |
torch._inductor.config.epilogue_fusion = False | |
torch._inductor.config.fx_graph_cache = True | |
# torch._inductor.config.max_fusion_size = 128 | |
# torch._inductor.config.max_pointwise_cat_inputs = 16 | |
# torch._inductor.config.force_pointwise_cat = True | |
torch._inductor.config.aggressive_fusion = True | |
# torch._inductor.config.triton.unique_kernel_names = True | |
af.auto_functionalized_v2._cacheable = True | |
af.auto_functionalized._cacheable = True | |
ROPE_PRECISION = torch.float32 | |
SUPPORTED_GUIDANCE_SCALES = [i / 2 for i in range(41)] # 0, 0.5, 1.0, ..., 20.0 | |
CP_MESH = None | |
# Constants | |
DEVICE = "cuda" | |
DTYPE = torch.bfloat16 | |
MIN_STEPS = 2 | |
MAX_STEPS = 50 | |
SPATIAL_COMPRESSION_RATIO = 8 | |
PATCH_SIZE = 1 | |
IN_CHANNELS = 16 | |
BASE_IMAGE_SEQ_LEN = 256 | |
MAX_IMAGE_SEQ_LEN = 4096 | |
BASE_SHIFT = 0.5 | |
MAX_SHIFT = 1.15 | |
M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN) | |
B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN | |
# The following parameters are fixed for now | |
BATCH_SIZE = 1 | |
HEIGHT = 1024 | |
WIDTH = 1024 | |
LATENT_HEIGHT = HEIGHT // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2 | |
LATENT_WIDTH = WIDTH // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2 | |
T5_SEQUENCE_LENGTH = 512 | |
GUIDANCE_SCALE = 4.0 | |
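# Note on the latent geometry above: the VAE compresses 1024x1024 pixels by 8x spatially,
# and the transformer packs 2x2 latent patches into single tokens, so
# LATENT_HEIGHT = LATENT_WIDTH = 1024 // 8 // 2 = 64 and the image sequence length is
# 64 * 64 = 4096 tokens, on top of T5_SEQUENCE_LENGTH text tokens. The packed latent
# channel dimension is IN_CHANNELS * 2 * 2 = 64.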
def _attention_torch_cudnn(query, key, value): | |
query, key, value = (x.transpose(1, 2) for x in (query, key, value)) | |
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): | |
out = torch.nn.functional.scaled_dot_product_attention(query, key, value) | |
out = out.transpose(1, 2) | |
return out | |
def _attention_flash_attn_2(query, key, value): | |
return flash_attn_func(query, key, value) | |
# Wrap FA3 in a custom op so torch.compile(fullgraph=True) can trace through it
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") | |
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: | |
out, lse = flash_attn_3_func(query, key, value) | |
return out | |
@torch.library.register_fake("flash_attn_3::_flash_attn_forward") | |
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: | |
return torch.empty_like(query) | |
def _attention_flash_attn_3(query, key, value): | |
out = _wrapped_flash_attn_3(query, key, value) | |
return out | |
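# Default attention backend; overridden from --attention_backend in the __main__ block
# (torch cuDNN SDPA, Flash Attention 2, or Flash Attention 3).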
ATTENTION_OP = _attention_torch_cudnn | |
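# Context-parallel helpers: `shard` splits a tensor evenly along a dimension across the
# ranks of CP_MESH, and `unshard` all-gathers the local shards back into the full tensor
# via DTensor. Note that attention is still computed per rank on the local sequence shard
# only, with no ring-attention style KV rotation between ranks, which is presumably what
# the gist title refers to.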
class EquipartitionSharder: | |
@classmethod | |
@torch.compiler.disable | |
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: | |
assert tensor.size()[dim] % mesh.size() == 0 | |
return tensor.chunk(mesh.size(), dim=dim)[mesh.get_local_rank()] | |
@classmethod | |
@torch.compiler.disable | |
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: | |
tensor = tensor.contiguous() | |
result = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor() | |
return result | |
@torch.compile | |
def pointwise_add3_silu(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor: | |
return torch.nn.functional.silu(x + y + z) | |
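# The AdaLayerNorm variants below fold the modulation into a single fused op:
# torch.addcmul(shift, norm(x), 1 + scale) == shift + norm(x) * (1 + scale).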
class AdaLayerNormContinuous(nn.Module): | |
def __init__( | |
self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True | |
): | |
super().__init__() | |
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) | |
self.norm = torch.nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) | |
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: | |
emb = self.linear(emb) | |
scale, shift = emb.unsqueeze(1).chunk(2, dim=-1) | |
norm_x = self.norm(x) | |
x = torch.addcmul(shift, norm_x, 1 + scale) | |
return x | |
class AdaLayerNormZeroSingle(nn.Module): | |
def __init__(self, embedding_dim: int, bias=True): | |
super().__init__() | |
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) | |
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) | |
    def forward(self, x: torch.Tensor, emb: torch.Tensor):
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) | |
norm_x = self.norm(x) | |
x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa) | |
return x, gate_msa | |
class AdaLayerNormZero(nn.Module): | |
def __init__(self, embedding_dim: int, bias=True): | |
super().__init__() | |
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) | |
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) | |
def forward(self, x: torch.Tensor, emb: torch.Tensor): | |
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) | |
norm_x = self.norm(x) | |
x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa) | |
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp | |
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): | |
def __init__(self, embedding_dim, pooled_projection_dim): | |
super().__init__() | |
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) | |
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) | |
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) | |
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") | |
def forward(self, t_emb, guidance_emb, pooled_projection_emb): | |
# We precompute timestep_emb, guidance_emb and pooled_projection_emb only once outside the forward | |
raise NotImplementedError("The implementation precomputes required embeddings, so this should be unreachable.") | |
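# Joint text/image attention. Projections are expected to be fused into `to_qkv` /
# `to_added_qkv` via fuse_projections() before inference. Rotary embeddings are applied to
# interleaved (real, imag) pairs as x * cos + rotate_half(x) * sin, computed in
# ROPE_PRECISION and cast back to the activation dtype.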
class Attention(nn.Module): | |
def __init__( | |
self, | |
query_dim: int, | |
heads: int = 8, | |
dim_head: int = 64, | |
dropout: float = 0.0, | |
bias: bool = False, | |
qk_norm: Optional[str] = None, | |
added_kv_proj_dim: Optional[int] = None, | |
added_proj_bias: Optional[bool] = True, | |
out_bias: bool = True, | |
eps: float = 1e-5, | |
        out_dim: Optional[int] = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
elementwise_affine: bool = True, | |
): | |
super().__init__() | |
assert qk_norm == "rms_norm", "Flux uses RMSNorm" | |
self.inner_dim = out_dim if out_dim is not None else dim_head * heads | |
self.query_dim = query_dim | |
self.use_bias = bias | |
self.dropout = dropout | |
self.fused_projections = False | |
self.out_dim = out_dim if out_dim is not None else query_dim | |
self.context_pre_only = context_pre_only | |
self.pre_only = pre_only | |
self.heads = out_dim // dim_head if out_dim is not None else heads | |
self.added_proj_bias = added_proj_bias | |
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) | |
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) | |
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) | |
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) | |
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) | |
if not self.pre_only: | |
self.to_out = nn.ModuleList([]) | |
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) | |
if added_kv_proj_dim is not None: | |
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) | |
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) | |
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) | |
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) | |
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) | |
self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=out_bias) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
encoder_hidden_states: Optional[torch.Tensor] = None, | |
image_rotary_emb: Optional[torch.Tensor] = None, | |
) -> torch.Tensor: | |
cos, sin = image_rotary_emb if image_rotary_emb is not None else (None, None) | |
query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1) | |
query, key, value = (x.unflatten(2, (self.heads, -1)) for x in (query, key, value)) | |
query = self.norm_q(query) | |
key = self.norm_k(key) | |
if encoder_hidden_states is not None: | |
query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) | |
query_c, key_c, value_c = (x.unflatten(2, (self.heads, -1)) for x in (query_c, key_c, value_c)) | |
query_c = self.norm_added_q(query_c) | |
key_c = self.norm_added_k(key_c) | |
query = torch.cat([query_c, query], dim=1) | |
key = torch.cat([key_c, key], dim=1) | |
value = torch.cat([value_c, value], dim=1) | |
if image_rotary_emb is not None: | |
x_real, x_imag = query.unflatten(-1, (-1, 2)).unbind(-1) | |
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) | |
query = (query.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(query) | |
x_real, x_imag = key.unflatten(-1, (-1, 2)).unbind(-1) | |
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) | |
key = (key.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(key) | |
hidden_states = ATTENTION_OP(query, key, value) | |
hidden_states = hidden_states.flatten(2, 3) | |
if encoder_hidden_states is not None: | |
encoder_hidden_states, hidden_states = torch.split_with_sizes( | |
hidden_states, | |
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], | |
dim=1, | |
) | |
hidden_states = self.to_out[0](hidden_states) | |
encoder_hidden_states = self.to_add_out(encoder_hidden_states) | |
return hidden_states, encoder_hidden_states | |
return hidden_states | |
@torch.no_grad() | |
def fuse_projections(self): | |
device = self.to_q.weight.data.device | |
dtype = self.to_q.weight.data.dtype | |
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) | |
in_features = concatenated_weights.shape[1] | |
out_features = concatenated_weights.shape[0] | |
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) | |
self.to_qkv.weight.copy_(concatenated_weights) | |
if self.use_bias: | |
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) | |
self.to_qkv.bias.copy_(concatenated_bias) | |
if ( | |
getattr(self, "add_q_proj", None) is not None | |
and getattr(self, "add_k_proj", None) is not None | |
and getattr(self, "add_v_proj", None) is not None | |
): | |
concatenated_weights = torch.cat( | |
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] | |
) | |
in_features = concatenated_weights.shape[1] | |
out_features = concatenated_weights.shape[0] | |
self.to_added_qkv = nn.Linear( | |
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype | |
) | |
self.to_added_qkv.weight.copy_(concatenated_weights) | |
if self.added_proj_bias: | |
concatenated_bias = torch.cat( | |
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] | |
) | |
self.to_added_qkv.bias.copy_(concatenated_bias) | |
for layer in ("to_q", "to_k", "to_v", "to_added_q", "to_added_k", "to_added_v"): | |
if hasattr(self, layer): | |
module = getattr(self, layer) | |
module.to("meta") | |
delattr(self, layer) | |
self.fused_projections = True | |
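# Rotary position embedding: one 1D rotary table per id channel (axes_dim = (16, 56, 56),
# summing to the head dimension), concatenated along the last axis and returned with shape
# [1, seq, 1, head_dim] so it broadcasts over batch and heads.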
class FluxPosEmbed(nn.Module): | |
def __init__(self, theta: int, axes_dim: List[int]): | |
super().__init__() | |
self.theta = theta | |
self.axes_dim = axes_dim | |
def forward(self, ids: torch.Tensor) -> torch.Tensor: | |
n_axes = ids.shape[-1] | |
cos_out = [] | |
sin_out = [] | |
for i in range(n_axes): | |
cos, sin = get_1d_rotary_pos_embed( | |
self.axes_dim[i], | |
ids[:, i], | |
theta=self.theta, | |
repeat_interleave_real=True, | |
use_real=True, | |
freqs_dtype=ROPE_PRECISION, | |
) | |
cos_out.append(cos) | |
sin_out.append(sin) | |
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None] | |
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None] | |
return freqs_cos, freqs_sin | |
class FluxSingleTransformerBlock(nn.Module): | |
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): | |
super().__init__() | |
self.mlp_hidden_dim = int(dim * mlp_ratio) | |
self.norm = AdaLayerNormZeroSingle(dim) | |
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) | |
self.act_mlp = nn.GELU(approximate="tanh") | |
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) | |
self.attn = Attention( | |
query_dim=dim, | |
dim_head=attention_head_dim, | |
heads=num_attention_heads, | |
out_dim=dim, | |
bias=True, | |
qk_norm="rms_norm", | |
eps=1e-6, | |
pre_only=True, | |
) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
temb: torch.Tensor, | |
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
) -> torch.Tensor: | |
residual = hidden_states | |
norm_hidden_states, gate = self.norm(hidden_states, emb=temb) | |
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) | |
attn_output = self.attn(hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb) | |
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) | |
hidden_states = self.proj_out(hidden_states) | |
hidden_states = torch.addcmul(residual, gate, hidden_states) | |
return hidden_states | |
class FluxTransformerBlock(nn.Module): | |
def __init__( | |
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 | |
): | |
super().__init__() | |
self.norm1 = AdaLayerNormZero(dim) | |
self.norm1_context = AdaLayerNormZero(dim) | |
self.attn = Attention( | |
query_dim=dim, | |
added_kv_proj_dim=dim, | |
dim_head=attention_head_dim, | |
heads=num_attention_heads, | |
out_dim=dim, | |
context_pre_only=False, | |
bias=True, | |
qk_norm=qk_norm, | |
eps=eps, | |
) | |
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) | |
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") | |
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) | |
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
encoder_hidden_states: torch.Tensor, | |
temb: torch.Tensor, | |
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
) -> Tuple[torch.Tensor, torch.Tensor]: | |
temb, temb_context = temb.chunk(2, dim=-1) | |
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) | |
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( | |
encoder_hidden_states, emb=temb_context | |
) | |
attn_output, context_attn_output = self.attn( | |
hidden_states=norm_hidden_states, | |
encoder_hidden_states=norm_encoder_hidden_states, | |
image_rotary_emb=image_rotary_emb, | |
) | |
hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output) | |
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output) | |
norm_hidden_states = self.norm2(hidden_states) | |
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) | |
norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp) | |
norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp) | |
ff_output = self.ff(norm_hidden_states) | |
context_ff_output = self.ff_context(norm_encoder_hidden_states) | |
hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output) | |
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output) | |
return encoder_hidden_states, hidden_states | |
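# The model below differs from the standard diffusers FluxTransformer2DModel in a few
# ways: timestep/guidance/text conditioning arrives precomputed, all per-block AdaLN
# projections are evaluated in one fused matmul (see fuse_adaln_linear_), and the Euler
# update x_{t+1} = x_t + dt * v is folded into the forward so an entire denoising step can
# be captured in a single CUDA graph. With CP_MESH set, sequences are sharded across ranks.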
class FluxTransformer2DModel( | |
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin | |
): | |
@register_to_config | |
def __init__( | |
self, | |
patch_size: int = 1, | |
in_channels: int = 64, | |
out_channels: Optional[int] = None, | |
num_layers: int = 19, | |
num_single_layers: int = 38, | |
attention_head_dim: int = 128, | |
num_attention_heads: int = 24, | |
joint_attention_dim: int = 4096, | |
pooled_projection_dim: int = 768, | |
guidance_embeds: bool = False, | |
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), | |
): | |
super().__init__() | |
self.out_channels = out_channels or in_channels | |
self.inner_dim = num_attention_heads * attention_head_dim | |
if not guidance_embeds: | |
raise ValueError("FLUX.1-schnell inference is not yet supported.") | |
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) | |
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings( | |
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim | |
) | |
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) | |
self.x_embedder = nn.Linear(in_channels, self.inner_dim) | |
self.transformer_blocks = nn.ModuleList( | |
[ | |
FluxTransformerBlock( | |
dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim | |
) | |
for _ in range(num_layers) | |
] | |
) | |
self.single_transformer_blocks = nn.ModuleList( | |
[ | |
FluxSingleTransformerBlock( | |
dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim | |
) | |
for _ in range(num_single_layers) | |
] | |
) | |
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) | |
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
encoder_hidden_states: torch.Tensor, | |
conditioning: torch.Tensor, | |
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], | |
dt: torch.Tensor, | |
) -> torch.Tensor: | |
x_t = hidden_states | |
if CP_MESH is not None: | |
hidden_states = EquipartitionSharder.shard(hidden_states, 1, CP_MESH) | |
encoder_hidden_states = EquipartitionSharder.shard(encoder_hidden_states, 1, CP_MESH) | |
image_rotary_emb = ( | |
EquipartitionSharder.shard(image_rotary_emb[0], 1, CP_MESH), | |
EquipartitionSharder.shard(image_rotary_emb[1], 1, CP_MESH), | |
) | |
hidden_states = self.x_embedder(hidden_states) | |
adaln_linear_states = self.adaln_linear(conditioning).unsqueeze(1).chunk(self.config.num_layers, dim=-1) | |
adaln_linear_single_states = ( | |
self.adaln_linear_single(conditioning).unsqueeze(1).chunk(self.config.num_single_layers, dim=-1) | |
) | |
for i, block in enumerate(self.transformer_blocks): | |
encoder_hidden_states, hidden_states = block( | |
hidden_states=hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
temb=adaln_linear_states[i], | |
image_rotary_emb=image_rotary_emb, | |
) | |
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) | |
for i, block in enumerate(self.single_transformer_blocks): | |
hidden_states = block( | |
hidden_states=hidden_states, | |
temb=adaln_linear_single_states[i], | |
image_rotary_emb=image_rotary_emb, | |
) | |
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] | |
hidden_states = self.norm_out(hidden_states, conditioning) | |
v = self.proj_out(hidden_states) | |
if CP_MESH is not None: | |
v = EquipartitionSharder.unshard(v, 1, CP_MESH) | |
x = x_t + dt * v | |
return x | |
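# fuse_qkv_ calls Attention.fuse_projections() on every attention module, concatenating the
# q/k/v (and added q/k/v) weights into single to_qkv / to_added_qkv linears so each
# projection runs as one GEMM; the original per-projection layers are moved to the meta
# device and deleted.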
@torch.no_grad() | |
def fuse_qkv_(model: FluxTransformer2DModel) -> FluxTransformer2DModel: | |
for submodule in model.modules(): | |
if not isinstance(submodule, Attention): | |
continue | |
submodule.fuse_projections() | |
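# fuse_adaln_linear_ concatenates the AdaLN modulation linears of every transformer block
# into two large projections (one for the dual-stream blocks, one for the single-stream
# blocks). The conditioning vector is then projected once per step and chunked per block
# inside FluxTransformer2DModel.forward.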
@torch.no_grad() | |
def fuse_adaln_linear_(model: FluxTransformer2DModel) -> FluxTransformer2DModel: | |
adaln_linear_weights = [] | |
adaln_linear_biases = [] | |
for block in model.transformer_blocks: | |
adaln_linear_weights.append(block.norm1.linear.weight.data.clone()) | |
adaln_linear_weights.append(block.norm1_context.linear.weight.data.clone()) | |
adaln_linear_biases.append(block.norm1.linear.bias.data.clone()) | |
adaln_linear_biases.append(block.norm1_context.linear.bias.data.clone()) | |
block.norm1.linear.to("meta") | |
block.norm1_context.linear.to("meta") | |
del block.norm1.linear, block.norm1_context.linear | |
adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0) | |
adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0) | |
in_features = adaln_linear_weights.shape[1] | |
out_features = adaln_linear_weights.shape[0] | |
model.adaln_linear = torch.nn.Linear( | |
in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype | |
) | |
model.adaln_linear.weight.copy_(adaln_linear_weights) | |
model.adaln_linear.bias.copy_(adaln_linear_biases) | |
adaln_linear_weights = [] | |
adaln_linear_biases = [] | |
for block in model.single_transformer_blocks: | |
adaln_linear_weights.append(block.norm.linear.weight.data.clone()) | |
adaln_linear_biases.append(block.norm.linear.bias.data.clone()) | |
block.norm.linear.to("meta") | |
del block.norm.linear | |
adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0) | |
adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0) | |
in_features = adaln_linear_weights.shape[1] | |
out_features = adaln_linear_weights.shape[0] | |
model.adaln_linear_single = torch.nn.Linear( | |
in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype | |
) | |
model.adaln_linear_single.weight.copy_(adaln_linear_weights) | |
model.adaln_linear_single.bias.copy_(adaln_linear_biases) | |
def prepare_clip_embeddings( | |
text_encoder: CLIPTextModel, | |
tokenizer: CLIPTokenizer, | |
prompt: str, | |
device: torch.device, | |
dtype: torch.dtype, | |
max_length: int = 77, | |
) -> torch.Tensor: | |
prompt = [prompt] | |
text_inputs = tokenizer( | |
prompt, | |
padding="max_length", | |
max_length=max_length, | |
truncation=True, | |
return_overflowing_tokens=False, | |
return_length=False, | |
return_tensors="pt", | |
) | |
text_input_ids = text_inputs.input_ids | |
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) | |
prompt_embeds = prompt_embeds.pooler_output.to(dtype) | |
return prompt_embeds | |
def prepare_t5_embeddings( | |
text_encoder: T5EncoderModel, | |
tokenizer: T5TokenizerFast, | |
prompt: str, | |
device: torch.device, | |
dtype: torch.dtype, | |
max_length: int = 512, | |
) -> torch.Tensor: | |
prompt = [prompt] | |
text_inputs = tokenizer( | |
prompt, | |
padding="max_length", | |
max_length=max_length, | |
truncation=True, | |
return_length=False, | |
return_overflowing_tokens=False, | |
return_tensors="pt", | |
) | |
text_input_ids = text_inputs.input_ids | |
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0] | |
prompt_embeds = prompt_embeds.to(dtype) | |
return prompt_embeds | |
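# Latent image ids are (id, row, col) triples per packed 2x2 latent patch; channel 0 stays
# zero (shared with the all-zero text ids), while channels 1 and 2 carry the spatial
# position consumed by the rotary embedding.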
def prepare_latent_image_ids(height: int, width: int, device: torch.device, dtype: torch.dtype): | |
latent_image_ids = torch.zeros(height, width, 3) | |
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] | |
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] | |
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape | |
latent_image_ids = latent_image_ids.reshape( | |
latent_image_id_height * latent_image_id_width, latent_image_id_channels | |
).contiguous() | |
return latent_image_ids.to(device=device, dtype=dtype) | |
def precompute_guidance_embeds(transformer: FluxTransformer2DModel): | |
embeds = {} | |
for guidance_scale in SUPPORTED_GUIDANCE_SCALES: | |
guidance = torch.full([1], guidance_scale, device="cuda", dtype=torch.float32) | |
guidance = transformer.time_text_embed.guidance_embedder( | |
transformer.time_text_embed.time_proj(guidance * 1000.0).to(dtype=torch.bfloat16) | |
) | |
embeds[f"{guidance_scale:.1f}"] = guidance | |
return embeds | |
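# Timestep schedule: sigmas follow the FLUX resolution-dependent time shift,
# sigma' = exp(mu) / (exp(mu) + (1 / sigma - 1)) with mu = B + image_seq_len * M, so longer
# image sequences push more of the schedule toward high noise. A final zero sigma is
# appended, and the corresponding timestep embeddings are precomputed for every supported
# step count.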
def precompute_timestep_embeds(transformer: FluxTransformer2DModel): | |
embeds = {} | |
image_seq_len = LATENT_HEIGHT * LATENT_WIDTH | |
for num_inference_steps in range(MIN_STEPS, MAX_STEPS + 1): | |
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) | |
mu = B + image_seq_len * M | |
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0) | |
sigmas = torch.from_numpy(sigmas).to(DEVICE, dtype=torch.float32) | |
sigmas = torch.cat([sigmas, sigmas.new_zeros(1)]) | |
timesteps = (sigmas * 1000.0).to(DTYPE) | |
temb = transformer.time_text_embed.time_proj(timesteps) | |
temb = transformer.time_text_embed.timestep_embedder(temb.to(DTYPE)) | |
embeds[num_inference_steps] = (sigmas, temb) | |
return embeds | |
def precompute_embeds(transformer: FluxTransformer2DModel, save_dir: str): | |
save_dir = pathlib.Path(save_dir) | |
save_dir.mkdir(parents=True, exist_ok=True) | |
guidance_path = save_dir / "guidance_embeds.pt" | |
timestep_path = save_dir / "timestep_embeds.pt" | |
if guidance_path.exists(): | |
guidance_embeds = torch.load(guidance_path, map_location=DEVICE, weights_only=True) | |
print(f'Loaded precomputed guidance embeddings from "{guidance_path.as_posix()}"') | |
else: | |
guidance_embeds = precompute_guidance_embeds(transformer) | |
if CP_MESH is None or CP_MESH.get_local_rank() == 0: | |
torch.save(guidance_embeds, guidance_path.as_posix()) | |
print(f'Precomputed guidance embeddings saved to "{save_dir.as_posix()}"') | |
if timestep_path.exists(): | |
timestep_embeds = torch.load(timestep_path, map_location=DEVICE, weights_only=True) | |
print(f'Loaded precomputed timestep embeddings from "{timestep_path.as_posix()}"') | |
else: | |
timestep_embeds = precompute_timestep_embeds(transformer) | |
if CP_MESH is None or CP_MESH.get_local_rank() == 0: | |
torch.save(timestep_embeds, timestep_path.as_posix()) | |
print(f'Precomputed timestep embeddings saved to "{save_dir.as_posix()}"') | |
return guidance_embeds, timestep_embeds | |
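# CUDA graph capture: inputs that change between steps (latents, conditioning, dt) are
# cloned into static tensors that are copied into before each graph.replay(); tensors that
# stay constant across steps (text states, rotary embeddings) are captured as-is. The
# warmup iterations in main() run before capture so lazy initialization does not happen
# inside the graph.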
def capture_cudagraph( | |
model: FluxTransformer2DModel, | |
latents: torch.Tensor, | |
encoder_hidden_states: torch.Tensor, | |
conditioning: torch.Tensor, | |
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], | |
dt: torch.Tensor, | |
): | |
print("Capturing CUDA graph...") | |
static_latents = latents.clone() | |
static_conditioning = conditioning.clone() | |
static_dt = dt.clone() | |
graph = torch.cuda.CUDAGraph() | |
with torch.cuda.graph(graph): | |
static_x = model( | |
hidden_states=static_latents, | |
encoder_hidden_states=encoder_hidden_states, | |
conditioning=static_conditioning, | |
image_rotary_emb=image_rotary_emb, | |
dt=static_dt, | |
) | |
return graph, static_latents, static_conditioning, static_dt, static_x | |
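# main() flow: load the FLUX models, fuse QKV and AdaLN weights, optionally torch.compile
# the transformer, precompute guidance/timestep/text conditioning, run the denoising loop
# (eagerly or via CUDA graph replay), then unpack the latents and decode with the VAE.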
@torch.inference_mode() | |
@torch.compiler.set_stance("default", skip_guard_eval_unsafe=True) | |
def main( | |
model_id: str, | |
working_dir: str = "/tmp/flux_precomputation", | |
num_inference_steps: int = 28, | |
mode: str = "none", | |
cache_dir: Optional[str] = None, | |
enable_cudagraph: bool = False, | |
enable_profiling: bool = False, | |
): | |
prompt = "A cat holding a sign that says 'Hello, World!'" | |
transformer = FluxTransformer2DModel.from_pretrained( | |
model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=DTYPE | |
) | |
text_encoder = CLIPTextModel.from_pretrained( | |
model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=DTYPE | |
) | |
text_encoder_2 = T5EncoderModel.from_pretrained( | |
model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=DTYPE | |
) | |
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir) | |
tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir) | |
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=DTYPE) | |
image_processor = VaeImageProcessor(vae_scale_factor=8 * 2) | |
fuse_qkv_(transformer) | |
fuse_adaln_linear_(transformer) | |
    for module in (transformer, text_encoder, text_encoder_2, vae):
        module.to(DEVICE)
if mode != "none": | |
transformer = torch.compile(transformer, mode=mode, fullgraph=True, dynamic=False) | |
guidance_embeds, timestep_embeds = precompute_embeds(transformer, working_dir) | |
latents = torch.randn((BATCH_SIZE, LATENT_HEIGHT * LATENT_WIDTH, IN_CHANNELS * 2 * 2), dtype=DTYPE, device=DEVICE) | |
pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, DEVICE, DTYPE) | |
encoder_hidden_states = prepare_t5_embeddings( | |
text_encoder_2, tokenizer_2, prompt, DEVICE, DTYPE, T5_SEQUENCE_LENGTH | |
) | |
# <precompute> | |
guidance_conditioning = guidance_embeds[f"{GUIDANCE_SCALE:.1f}"] | |
sigmas, timestep_conditioning = timestep_embeds[num_inference_steps] | |
pooled_projections = transformer.time_text_embed.text_embedder(pooled_projections) | |
encoder_hidden_states = transformer.context_embedder(encoder_hidden_states) | |
# </precompute> | |
img_ids = prepare_latent_image_ids(LATENT_HEIGHT, LATENT_WIDTH, device=DEVICE, dtype=DTYPE) | |
text_ids = torch.zeros(T5_SEQUENCE_LENGTH, 3).to(device=DEVICE, dtype=DTYPE) | |
ids = torch.cat([text_ids, img_ids], dim=0).float() | |
image_rotary_emb = transformer.pos_embed(ids) | |
dt = sigmas[1:] - sigmas[:-1] | |
print("Warmup step...") | |
for _ in range(2): | |
conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections) | |
_ = transformer( | |
hidden_states=latents, | |
encoder_hidden_states=encoder_hidden_states, | |
conditioning=conditioning, | |
image_rotary_emb=image_rotary_emb, | |
dt=dt[0], | |
) | |
torch.cuda.synchronize() | |
start_event = torch.cuda.Event(enable_timing=True) | |
end_event = torch.cuda.Event(enable_timing=True) | |
context = ( | |
profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True) | |
if enable_profiling | |
else contextlib.nullcontext() | |
) | |
if not enable_cudagraph: | |
with context as ctx: | |
start_event.record() | |
for i in range(num_inference_steps): | |
conditioning = pointwise_add3_silu( | |
timestep_conditioning[i, :], guidance_conditioning, pooled_projections | |
) | |
latents = transformer( | |
hidden_states=latents, | |
encoder_hidden_states=encoder_hidden_states, | |
conditioning=conditioning, | |
image_rotary_emb=image_rotary_emb, | |
dt=dt[i], | |
) | |
end_event.record() | |
torch.cuda.synchronize() | |
else: | |
graph, static_latents, static_conditioning, static_dt, static_x = capture_cudagraph( | |
transformer, | |
latents, | |
encoder_hidden_states, | |
pooled_projections + guidance_conditioning, | |
image_rotary_emb, | |
dt[0], | |
) | |
static_x.copy_(latents) | |
with context as ctx: | |
start_event.record() | |
for i in range(num_inference_steps): | |
conditioning = pointwise_add3_silu( | |
timestep_conditioning[i, :], guidance_conditioning, pooled_projections | |
) | |
static_latents.copy_(static_x) | |
static_conditioning.copy_(conditioning) | |
static_dt.copy_(dt[i]) | |
graph.replay() | |
end_event.record() | |
torch.cuda.synchronize() | |
latents = static_x | |
total_time = start_event.elapsed_time(end_event) / 1000.0 | |
print(f"time: {total_time:.5f}s") | |
if enable_profiling: | |
print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) | |
ctx.export_chrome_trace(f"dump_benchmark_flux_normal_{mode}.json") | |
if CP_MESH is None or CP_MESH.get_local_rank() == 0: | |
latents = latents.reshape(-1, LATENT_HEIGHT, LATENT_WIDTH, IN_CHANNELS, 2, 2) | |
latents = latents.permute(0, 3, 1, 4, 2, 5) | |
latents = latents.flatten(4, 5).flatten(2, 3) | |
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor | |
image = vae.decode(latents, return_dict=False)[0] | |
image = image_processor.postprocess(image, output_type="pil")[0] | |
image.save("output.png") | |
if __name__ == "__main__": | |
model_id = "black-forest-labs/FLUX.1-dev" | |
cache_dir = None | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--cp_degree", type=int, default=1, help="Degree of context parallelism.") | |
parser.add_argument("--working_dir", type=str, default="/tmp/flux_precomputation") | |
parser.add_argument("--enable_cudagraph", action="store_true") | |
parser.add_argument( | |
"--compile_mode", | |
type=str, | |
default="none", | |
choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], | |
) | |
parser.add_argument("--num_inference_steps", type=int, default=28) | |
parser.add_argument("--reduce_rope_precision", action="store_true") | |
parser.add_argument("--attention_backend", type=str, default="torch_cudnn", choices=["torch_cudnn", "fa2", "fa3"]) | |
parser.add_argument("--enable_profiling", action="store_true") | |
parser.add_argument("--seed", type=int, default=42) | |
args = parser.parse_args() | |
    if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]:
        raise ValueError("CUDAGraph requires compile_mode to be 'none', 'default', or 'max-autotune-no-cudagraphs'.")
    if not (MIN_STEPS <= args.num_inference_steps <= MAX_STEPS):
        raise ValueError(f"num_inference_steps must be between {MIN_STEPS} and {MAX_STEPS}.")
ROPE_PRECISION = torch.bfloat16 if args.reduce_rope_precision else torch.float32 | |
if args.attention_backend == "fa2": | |
ATTENTION_OP = _attention_flash_attn_2 | |
elif args.attention_backend == "fa3": | |
ATTENTION_OP = _attention_flash_attn_3 | |
torch.manual_seed(args.seed) | |
try: | |
if args.cp_degree > 1: | |
dist.init_process_group("nccl") | |
rank, world_size = dist.get_rank(), dist.get_world_size() | |
torch.cuda.set_device(rank) | |
CP_MESH = dist.device_mesh.init_device_mesh("cuda", [world_size], mesh_dim_names=["cp"]) | |
if args.enable_profiling and args.enable_cudagraph: | |
torch.profiler._utils._init_for_cuda_graphs() | |
main( | |
model_id=model_id, | |
working_dir=args.working_dir, | |
num_inference_steps=args.num_inference_steps, | |
mode=args.compile_mode, | |
cache_dir=cache_dir, | |
enable_cudagraph=args.enable_cudagraph, | |
enable_profiling=args.enable_profiling, | |
) | |
finally: | |
if dist.is_initialized(): | |
dist.destroy_process_group() |