import argparse
import contextlib
import math
import pathlib
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.profiler._utils
import torch._dynamo.config
import torch._inductor.config
import torch._higher_order_ops.auto_functionalize as af
import triton
import triton.language as tl
from torch.profiler import profile, record_function, ProfilerActivity
try:
    from flash_attn import flash_attn_func
except ImportError:
    print("Flash Attention 2 not found.")
try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    print("Flash Attention 3 not found.")
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import (
    get_1d_rotary_pos_embed,
    TimestepEmbedding,
    Timesteps,
    PixArtAlphaTextProjection,
)
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.modeling_utils import ModelMixin
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
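# torch.compile / TorchInductor configuration. These flags tune how Inductor autotunes and fuses
# its generated kernels and enable the FX graph cache so compiled artifacts can be reused across
# runs; the commented-out flags are additional knobs left in for experimentation. Marking the
# auto_functionalized higher-order ops as cacheable appears to be needed for them to participate
# in that cache.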
torch._dynamo.config.inline_inbuilt_nn_modules = False
torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.coordinate_descent_search_radius = 1
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.fx_graph_cache = True
# torch._inductor.config.max_fusion_size = 128
# torch._inductor.config.max_pointwise_cat_inputs = 16
# torch._inductor.config.force_pointwise_cat = True
torch._inductor.config.aggressive_fusion = True
# torch._inductor.config.triton.unique_kernel_names = True
af.auto_functionalized_v2._cacheable = True
af.auto_functionalized._cacheable = True
ROPE_PRECISION = torch.float32
SUPPORTED_GUIDANCE_SCALES = [i / 2 for i in range(41)]  # 0, 0.5, 1.0, ..., 20.0
# Constants
DEVICE = "cuda"
DTYPE = torch.bfloat16
MIN_STEPS = 2
MAX_STEPS = 50
SPATIAL_COMPRESSION_RATIO = 8
PATCH_SIZE = 1
IN_CHANNELS = 16
BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15
M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN)
B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN
# The following parameters are fixed for now
BATCH_SIZE = 1
HEIGHT = 1024
WIDTH = 1024
LATENT_HEIGHT = HEIGHT // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2
LATENT_WIDTH = WIDTH // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2
T5_SEQUENCE_LENGTH = 512
GUIDANCE_SCALE = 4.0
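# Note: with the values above, a 1024x1024 image maps to a 128x128 VAE latent that is packed 2x2
# into 64x64 = 4096 image tokens of 16 * 2 * 2 = 64 channels each; together with the 512 T5 text
# tokens, the joint attention sequence length is 4608.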
def _attention_torch_cudnn(query, key, value):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out
def _attention_flash_attn_2(query, key, value):
    return flash_attn_func(query, key, value)
# For fullgraph=True tracing to be compatible
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out, lse = flash_attn_3_func(query, key, value)
    return out
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)
def _attention_flash_attn_3(query, key, value):
    out = _wrapped_flash_attn_3(query, key, value)
    return out
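# The attention kernel is selected once at startup via --attention_backend and stored in the
# ATTENTION_OP global below. All three backends take (batch, seq, heads, head_dim) inputs; the
# cuDNN SDPA path transposes to (batch, heads, seq, head_dim) internally. FA3 is wrapped in a
# custom op with a fake (meta) registration so it can be traced with torch.compile(fullgraph=True).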
ATTENTION_OP = _attention_torch_cudnn
@torch.compile
def pointwise_add3_silu(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x + y + z)
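# pointwise_add3_silu above fuses the per-step combination of the precomputed timestep, guidance
# and pooled text embeddings (add + add + SiLU) into one compiled kernel; this replaces the work
# CombinedTimestepGuidanceTextProjEmbeddings would otherwise repeat inside every forward pass.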
class AdaLayerNormContinuous(nn.Module):
    def __init__(
        self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True
    ):
        super().__init__()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        emb = self.linear(emb)
        scale, shift = emb.unsqueeze(1).chunk(2, dim=-1)
        norm_x = self.norm(x)
        x = torch.addcmul(shift, norm_x, 1 + scale)
        return x
class AdaLayerNormZeroSingle(nn.Module):
    def __init__(self, embedding_dim: int, bias=True):
        super().__init__()
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None):
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
        norm_x = self.norm(x)
        x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa)
        return x, gate_msa
class AdaLayerNormZero(nn.Module):
    def __init__(self, embedding_dim: int, bias=True):
        super().__init__()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        norm_x = self.norm(x)
        x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
    def forward(self, t_emb, guidance_emb, pooled_projection_emb):
        # We precompute timestep_emb, guidance_emb and pooled_projection_emb only once outside the forward
        raise NotImplementedError("The implementation precomputes required embeddings, so this should be unreachable.")
class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
    ):
        super().__init__()
        assert qk_norm == "rms_norm", "Flux uses RMSNorm"
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_proj_bias = added_proj_bias
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=out_bias)
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        cos, sin = image_rotary_emb if image_rotary_emb is not None else (None, None)
        query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1)
        query, key, value = (x.unflatten(2, (self.heads, -1)) for x in (query, key, value))
        query = self.norm_q(query)
        key = self.norm_k(key)
        if encoder_hidden_states is not None:
            query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
            query_c, key_c, value_c = (x.unflatten(2, (self.heads, -1)) for x in (query_c, key_c, value_c))
            query_c = self.norm_added_q(query_c)
            key_c = self.norm_added_k(key_c)
            query = torch.cat([query_c, query], dim=1)
            key = torch.cat([key_c, key], dim=1)
            value = torch.cat([value_c, value], dim=1)
        if image_rotary_emb is not None:
            x_real, x_imag = query.unflatten(-1, (-1, 2)).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
            query = (query.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(query)
            x_real, x_imag = key.unflatten(-1, (-1, 2)).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
            key = (key.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(key)
        hidden_states = ATTENTION_OP(query, key, value)
        hidden_states = hidden_states.flatten(2, 3)
        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = torch.split_with_sizes(
                hidden_states,
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]],
                dim=1,
            )
            hidden_states = self.to_out[0](hidden_states)
            encoder_hidden_states = self.to_add_out(encoder_hidden_states)
            return hidden_states, encoder_hidden_states
        return hidden_states
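    # fuse_projections (below) concatenates the separate q/k/v (and added q/k/v) projection
    # weights into single to_qkv / to_added_qkv linears so each attention block issues one GEMM
    # per stream instead of three; the original per-projection modules are moved to the meta
    # device and deleted afterwards.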
    @torch.no_grad()
    def fuse_projections(self):
        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype
        concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
        in_features = concatenated_weights.shape[1]
        out_features = concatenated_weights.shape[0]
        self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
        self.to_qkv.weight.copy_(concatenated_weights)
        if self.use_bias:
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            self.to_qkv.bias.copy_(concatenated_bias)
        if (
            getattr(self, "add_q_proj", None) is not None
            and getattr(self, "add_k_proj", None) is not None
            and getattr(self, "add_v_proj", None) is not None
        ):
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]
            self.to_added_qkv = nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            self.to_added_qkv.weight.copy_(concatenated_weights)
            if self.added_proj_bias:
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                self.to_added_qkv.bias.copy_(concatenated_bias)
        for layer in ("to_q", "to_k", "to_v", "add_q_proj", "add_k_proj", "add_v_proj"):
            if hasattr(self, layer):
                module = getattr(self, layer)
                module.to("meta")
                delattr(self, layer)
        self.fused_projections = True
class FluxPosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                ids[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=ROPE_PRECISION,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None]
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None]
        return freqs_cos, freqs_sin
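# Note: FluxPosEmbed above builds the rotary cos/sin tables once from the concatenated text and
# image position ids (text ids are all zeros, image ids carry row/column indices); the result is
# reused unchanged for every denoising step and every attention layer.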
class FluxSingleTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )
    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        attn_output = self.attn(hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb)
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = torch.addcmul(residual, gate, hidden_states)
        return hidden_states
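# Note: in FluxSingleTransformerBlock the attention branch and the MLP branch both read the same
# modulated activations and run in parallel; their outputs are concatenated and mixed by a single
# proj_out linear, with torch.addcmul folding the residual add and the gate scaling into one kernel.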
class FluxTransformerBlock(nn.Module):
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()
        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)
        self.attn = Attention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            qk_norm=qk_norm,
            eps=eps,
        )
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        temb, temb_context = temb.chunk(2, dim=-1)
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb_context
        )
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output)
        encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output)
        norm_hidden_states = self.norm2(hidden_states)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp)
        norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp)
        ff_output = self.ff(norm_hidden_states)
        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output)
        encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output)
        return encoder_hidden_states, hidden_states
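# Note: because the per-block AdaLN linears are fused into one matmul (see fuse_adaln_linear_
# below), each FluxTransformerBlock receives its modulation vector pre-computed and simply splits
# it into the image half (norm1) and the text/context half (norm1_context).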
class FluxTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        if not guidance_embeds:
            raise ValueError("FLUX.1-schnell inference is not yet supported.")
        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
                )
                for _ in range(num_layers)
            ]
        )
        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
                )
                for _ in range(num_single_layers)
            ]
        )
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        dt: torch.Tensor,
    ) -> torch.Tensor:
        x_t = hidden_states
        hidden_states = self.x_embedder(hidden_states)
        adaln_linear_states = self.adaln_linear(conditioning).unsqueeze(1).chunk(self.config.num_layers, dim=-1)
        adaln_linear_single_states = (
            self.adaln_linear_single(conditioning).unsqueeze(1).chunk(self.config.num_single_layers, dim=-1)
        )
        for i, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=adaln_linear_states[i],
                image_rotary_emb=image_rotary_emb,
            )
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        for i, block in enumerate(self.single_transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                temb=adaln_linear_single_states[i],
                image_rotary_emb=image_rotary_emb,
            )
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
        hidden_states = self.norm_out(hidden_states, conditioning)
        v = self.proj_out(hidden_states)
        x = x_t + dt * v
        return x
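# Note: the forward above folds the Euler update x_next = x_t + dt * v directly into the model, so
# a full denoising step (embedding, all blocks, projection and the scheduler update) is one graph
# that torch.compile / CUDA graphs can capture end to end. The pooled-text projection and the
# context embedding are applied once in main() rather than once per step.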
@torch.no_grad()
def fuse_qkv_(model: FluxTransformer2DModel) -> None:
    for submodule in model.modules():
        if not isinstance(submodule, Attention):
            continue
        submodule.fuse_projections()
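# fuse_adaln_linear_ concatenates the AdaLN modulation linears of all dual-stream blocks into a
# single model.adaln_linear (and those of all single-stream blocks into model.adaln_linear_single),
# so all modulation vectors are produced by two large GEMMs per step instead of dozens of small
# ones; the per-block linears are then moved to the meta device and deleted.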
@torch.no_grad()
def fuse_adaln_linear_(model: FluxTransformer2DModel) -> None:
    adaln_linear_weights = []
    adaln_linear_biases = []
    for block in model.transformer_blocks:
        adaln_linear_weights.append(block.norm1.linear.weight.data.clone())
        adaln_linear_weights.append(block.norm1_context.linear.weight.data.clone())
        adaln_linear_biases.append(block.norm1.linear.bias.data.clone())
        adaln_linear_biases.append(block.norm1_context.linear.bias.data.clone())
        block.norm1.linear.to("meta")
        block.norm1_context.linear.to("meta")
        del block.norm1.linear, block.norm1_context.linear
    adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
    adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
    in_features = adaln_linear_weights.shape[1]
    out_features = adaln_linear_weights.shape[0]
    model.adaln_linear = torch.nn.Linear(
        in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
    )
    model.adaln_linear.weight.copy_(adaln_linear_weights)
    model.adaln_linear.bias.copy_(adaln_linear_biases)
    adaln_linear_weights = []
    adaln_linear_biases = []
    for block in model.single_transformer_blocks:
        adaln_linear_weights.append(block.norm.linear.weight.data.clone())
        adaln_linear_biases.append(block.norm.linear.bias.data.clone())
        block.norm.linear.to("meta")
        del block.norm.linear
    adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
    adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
    in_features = adaln_linear_weights.shape[1]
    out_features = adaln_linear_weights.shape[0]
    model.adaln_linear_single = torch.nn.Linear(
        in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
    )
    model.adaln_linear_single.weight.copy_(adaln_linear_weights)
    model.adaln_linear_single.bias.copy_(adaln_linear_biases)
def prepare_clip_embeddings(
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    max_length: int = 77,
) -> torch.Tensor:
    prompt = [prompt]
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
    prompt_embeds = prompt_embeds.pooler_output.to(dtype)
    return prompt_embeds
def prepare_t5_embeddings(
    text_encoder: T5EncoderModel,
    tokenizer: T5TokenizerFast,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    max_length: int = 512,
) -> torch.Tensor:
    prompt = [prompt]
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype)
    return prompt_embeds
def prepare_latent_image_ids(height: int, width: int, device: torch.device, dtype: torch.dtype):
    latent_image_ids = torch.zeros(height, width, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    ).contiguous()
    return latent_image_ids.to(device=device, dtype=dtype)
def precompute_guidance_embeds(transformer: FluxTransformer2DModel):
    embeds = {}
    for guidance_scale in SUPPORTED_GUIDANCE_SCALES:
        guidance = torch.full([1], guidance_scale, device="cuda", dtype=torch.float32)
        guidance = transformer.time_text_embed.guidance_embedder(
            transformer.time_text_embed.time_proj(guidance * 1000.0).to(DTYPE)
        )
        embeds[f"{guidance_scale:.1f}"] = guidance
    return embeds
def precompute_timestep_embeds(transformer: FluxTransformer2DModel):
    embeds = {}
    image_seq_len = LATENT_HEIGHT * LATENT_WIDTH
    for num_inference_steps in range(MIN_STEPS, MAX_STEPS + 1):
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        mu = B + image_seq_len * M
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0)
        sigmas = torch.from_numpy(sigmas).to(DEVICE, dtype=torch.float32)
        sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
        timesteps = (sigmas * 1000.0).to(DTYPE)
        temb = transformer.time_text_embed.time_proj(timesteps)
        temb = transformer.time_text_embed.timestep_embedder(temb.to(DTYPE))
        embeds[num_inference_steps] = (sigmas, temb)
    return embeds
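# Note: precompute_timestep_embeds mirrors the "dynamic shifting" flow-match schedule (a shift mu
# derived from the image sequence length via the M/B constants above, then an exponential time
# shift applied to the linearly spaced sigmas) and caches the sigma schedule plus timestep
# embeddings for every supported step count, so no schedule computation runs per step.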
def precompute_embeds(transformer: FluxTransformer2DModel, save_dir: str):
    save_dir = pathlib.Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    guidance_path = save_dir / "guidance_embeds.pt"
    timestep_path = save_dir / "timestep_embeds.pt"
    if guidance_path.exists():
        guidance_embeds = torch.load(guidance_path, map_location=DEVICE, weights_only=True)
        print(f'Loaded precomputed guidance embeddings from "{guidance_path.as_posix()}"')
    else:
        guidance_embeds = precompute_guidance_embeds(transformer)
        torch.save(guidance_embeds, guidance_path.as_posix())
        print(f'Precomputed guidance embeddings saved to "{save_dir.as_posix()}"')
    if timestep_path.exists():
        timestep_embeds = torch.load(timestep_path, map_location=DEVICE, weights_only=True)
        print(f'Loaded precomputed timestep embeddings from "{timestep_path.as_posix()}"')
    else:
        timestep_embeds = precompute_timestep_embeds(transformer)
        torch.save(timestep_embeds, timestep_path.as_posix())
        print(f'Precomputed timestep embeddings saved to "{save_dir.as_posix()}"')
    return guidance_embeds, timestep_embeds
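# capture_cudagraph records one full transformer forward into a CUDA graph with static input
# buffers (latents, conditioning, dt). At replay time the loop only copies new values into those
# static tensors and calls graph.replay(), which removes almost all per-step CPU launch overhead;
# encoder_hidden_states and the rotary embeddings are constant across steps, so they are captured
# as-is without a staging copy.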
def capture_cudagraph(
    model: FluxTransformer2DModel,
    latents: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    conditioning: torch.Tensor,
    image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
    dt: torch.Tensor,
):
    print("Capturing CUDA graph...")
    static_latents = latents.clone()
    static_conditioning = conditioning.clone()
    static_dt = dt.clone()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_x = model(
            hidden_states=static_latents,
            encoder_hidden_states=encoder_hidden_states,
            conditioning=static_conditioning,
            image_rotary_emb=image_rotary_emb,
            dt=static_dt,
        )
    return graph, static_latents, static_conditioning, static_dt, static_x
@torch.inference_mode()
# @torch.compiler.set_stance("default", skip_guard_eval_unsafe=True)
def main(
    model_id: str,
    prompt: str,
    seed: int,
    working_dir: str = "/tmp/flux_precomputation",
    num_inference_steps: int = 28,
    mode: str = "none",
    cache_dir: Optional[str] = None,
    enable_cudagraph: bool = False,
    enable_profiling: bool = False,
    attention_backend: Optional[str] = None,
):
    transformer = FluxTransformer2DModel.from_pretrained(
        model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=DTYPE
    )
    text_encoder = CLIPTextModel.from_pretrained(
        model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=DTYPE
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=DTYPE
    )
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
    tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=DTYPE)
    image_processor = VaeImageProcessor(vae_scale_factor=8 * 2)
    fuse_qkv_(transformer)
    fuse_adaln_linear_(transformer)
    for module in (transformer, text_encoder, text_encoder_2, vae):
        module.to(DEVICE)
    if mode != "none":
        transformer = torch.compile(transformer, mode=mode, fullgraph=True, dynamic=False)
    generator = torch.Generator(DEVICE).manual_seed(seed)
    guidance_embeds, timestep_embeds = precompute_embeds(transformer, working_dir)
    latents = torch.randn(
        (BATCH_SIZE, LATENT_HEIGHT * LATENT_WIDTH, IN_CHANNELS * 2 * 2), dtype=DTYPE, device=DEVICE, generator=generator
    )
    pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, DEVICE, DTYPE)
    encoder_hidden_states = prepare_t5_embeddings(
        text_encoder_2, tokenizer_2, prompt, DEVICE, DTYPE, T5_SEQUENCE_LENGTH
    )
    # <precompute>
    guidance_conditioning = guidance_embeds[f"{GUIDANCE_SCALE:.1f}"]
    sigmas, timestep_conditioning = timestep_embeds[num_inference_steps]
    pooled_projections = transformer.time_text_embed.text_embedder(pooled_projections)
    encoder_hidden_states = transformer.context_embedder(encoder_hidden_states)
    # </precompute>
    img_ids = prepare_latent_image_ids(LATENT_HEIGHT, LATENT_WIDTH, device=DEVICE, dtype=DTYPE)
    text_ids = torch.zeros(T5_SEQUENCE_LENGTH, 3).to(device=DEVICE, dtype=DTYPE)
    ids = torch.cat([text_ids, img_ids], dim=0).float()
    image_rotary_emb = transformer.pos_embed(ids)
    dt = sigmas[1:] - sigmas[:-1]
print("Warmup step...") | |
for _ in range(2): | |
conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections) | |
_ = transformer( | |
hidden_states=latents, | |
encoder_hidden_states=encoder_hidden_states, | |
conditioning=conditioning, | |
image_rotary_emb=image_rotary_emb, | |
dt=dt[0], | |
) | |
torch.cuda.synchronize() | |
start_event = torch.cuda.Event(enable_timing=True) | |
end_event = torch.cuda.Event(enable_timing=True) | |
context = ( | |
profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True) | |
if enable_profiling | |
else contextlib.nullcontext() | |
) | |
if not enable_cudagraph: | |
with context as ctx: | |
start_event.record() | |
for i in range(num_inference_steps): | |
conditioning = pointwise_add3_silu( | |
timestep_conditioning[i, :], guidance_conditioning, pooled_projections | |
) | |
latents = transformer( | |
hidden_states=latents, | |
encoder_hidden_states=encoder_hidden_states, | |
conditioning=conditioning, | |
image_rotary_emb=image_rotary_emb, | |
dt=dt[i], | |
) | |
end_event.record() | |
torch.cuda.synchronize() | |
else: | |
graph, static_latents, static_conditioning, static_dt, static_x = capture_cudagraph( | |
transformer, | |
latents, | |
encoder_hidden_states, | |
pooled_projections + guidance_conditioning, | |
image_rotary_emb, | |
dt[0], | |
) | |
static_x.copy_(latents) | |
with context as ctx: | |
start_event.record() | |
for i in range(num_inference_steps): | |
conditioning = pointwise_add3_silu( | |
timestep_conditioning[i, :], guidance_conditioning, pooled_projections | |
) | |
static_latents.copy_(static_x) | |
static_conditioning.copy_(conditioning) | |
static_dt.copy_(dt[i]) | |
graph.replay() | |
end_event.record() | |
torch.cuda.synchronize() | |
latents = static_x | |
total_time = start_event.elapsed_time(end_event) / 1000.0 | |
print(f"time: {total_time:.5f}s") | |
if enable_profiling: | |
print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) | |
ctx.export_chrome_trace(f"dump_benchmark_flux_normal_{mode}.json") | |
latents = latents.reshape(-1, LATENT_HEIGHT, LATENT_WIDTH, IN_CHANNELS, 2, 2) | |
latents = latents.permute(0, 3, 1, 4, 2, 5) | |
latents = latents.flatten(4, 5).flatten(2, 3) | |
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor | |
image = vae.decode(latents, return_dict=False)[0] | |
image = image_processor.postprocess(image, output_type="pil")[0] | |
rope_precision = str(ROPE_PRECISION).split(".")[-1] | |
image.save(f"dump_output---mode-{mode}---cudagraph-{enable_cudagraph}---attention-{attention_backend}---rope-{rope_precision}.png") | |
if __name__ == "__main__":
    model_id = "black-forest-labs/FLUX.1-dev"
    cache_dir = None
    default_prompt = "The King of Hearts card transforms into a 3D hologram that appears to be made of cosmic energy. As the King emerges, stars and galaxies swirl around him, creating a sense of traveling through the universe. The King's attire is adorned with celestial patterns, and his crown is a glowing star cluster. The hologram floats in front of you, with the background shifting through different cosmic scenes, from nebulae to black holes. Atmosphere: Perfect for space-themed events, science fiction conventions, or futuristic tech expos."
    parser = argparse.ArgumentParser()
    parser.add_argument("--working_dir", type=str, default="/tmp/flux_precomputation")
    parser.add_argument("--prompt", type=str, default=default_prompt)
    parser.add_argument("--enable_cudagraph", action="store_true")
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="none",
        choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
    )
    parser.add_argument("--num_inference_steps", type=int, default=28)
    parser.add_argument("--reduce_rope_precision", action="store_true")
    parser.add_argument("--attention_backend", type=str, default="torch_cudnn", choices=["torch_cudnn", "fa2", "fa3"])
    parser.add_argument("--enable_profiling", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]:
        raise ValueError("CUDAGraph requires compile_mode to be 'none', 'default', or 'max-autotune-no-cudagraphs'.")
    if args.num_inference_steps > 50 or args.num_inference_steps < 2:
        raise ValueError("num_inference_steps must be between 2 and 50.")
    ROPE_PRECISION = torch.bfloat16 if args.reduce_rope_precision else torch.float32
    if args.attention_backend == "fa2":
        ATTENTION_OP = _attention_flash_attn_2
    elif args.attention_backend == "fa3":
        ATTENTION_OP = _attention_flash_attn_3
    torch.manual_seed(args.seed)
    if args.enable_profiling and args.enable_cudagraph:
        torch.profiler._utils._init_for_cuda_graphs()
    main(
        model_id=model_id,
        prompt=args.prompt,
        seed=args.seed,
        working_dir=args.working_dir,
        num_inference_steps=args.num_inference_steps,
        mode=args.compile_mode,
        cache_dir=cache_dir,
        enable_cudagraph=args.enable_cudagraph,
        enable_profiling=args.enable_profiling,
        attention_backend=args.attention_backend,
    )
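# Example invocations (assuming this file is saved as, e.g., flux_inference.py; the filename is
# not part of the gist):
#   python flux_inference.py --compile_mode max-autotune --attention_backend fa3
#   python flux_inference.py --compile_mode max-autotune-no-cudagraphs --enable_cudagraph
#   python flux_inference.py --compile_mode none --enable_profiling --num_inference_steps 28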