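# Minimal standalone reimplementation of the Flux transformer (FLUX.1-dev) used to
# benchmark denoising throughput with optional torch.compile and CUDA graph replay.
# QKV projections and per-block AdaLN linears are fused ahead of time; the text
# encoders and VAE come straight from transformers/diffusers.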
import argparse
import contextlib
import math
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch._inductor.config
import torch._higher_order_ops.auto_functionalize as af
import triton
import triton.language as tl
from torch.profiler import profile, record_function, ProfilerActivity
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
)
from diffusers.models.modeling_utils import ModelMixin
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
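# Inductor settings: coordinate-descent tuning, FX graph caching, and aggressive
# fusion; marking the auto_functionalized higher-order ops cacheable keeps them
# from preventing FX graph caching.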
torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.epilogue_fusion = False
torch._inductor.config.fx_graph_cache = True
# torch._inductor.config.max_fusion_size = 128
# torch._inductor.config.max_pointwise_cat_inputs = 16
# torch._inductor.config.force_pointwise_cat = True
torch._inductor.config.aggressive_fusion = True
# torch._inductor.config.triton.unique_kernel_names = True
af.auto_functionalized_v2._cacheable = True
af.auto_functionalized._cacheable = True
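# Precision in which rotary embeddings are computed and applied; switched to
# bfloat16 when --reduce_rope_precision is passed.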
ROPE_PRECISION = torch.float32
class AdaLayerNormContinuous(nn.Module):
def __init__(
self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True
):
super().__init__()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
self.norm = torch.nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
emb = self.linear(emb)
scale, shift = emb.unsqueeze(1).chunk(2, dim=-1)
norm_x = self.norm(x)
x = torch.addcmul(shift, norm_x, 1 + scale)
return x
class AdaLayerNormZeroSingle(nn.Module):
def __init__(self, embedding_dim: int, bias=True):
super().__init__()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
    def forward(self, x: torch.Tensor, emb: torch.Tensor):
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
norm_x = self.norm(x)
x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa)
return x, gate_msa
class AdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, bias=True):
super().__init__()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
def forward(self, x: torch.Tensor, emb: torch.Tensor):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
norm_x = self.norm(x)
x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa)
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
        out_dim: Optional[int] = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
):
super().__init__()
assert qk_norm == "rms_norm", "Flux uses RMSNorm"
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=out_bias)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
cos, sin = image_rotary_emb if image_rotary_emb is not None else (None, None)
query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1)
query, key, value = (x.unflatten(2, (self.heads, -1)).transpose(1, 2) for x in (query, key, value))
query = self.norm_q(query)
key = self.norm_k(key)
if encoder_hidden_states is not None:
query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
query_c, key_c, value_c = (
x.unflatten(2, (self.heads, -1)).transpose(1, 2) for x in (query_c, key_c, value_c)
)
query_c = self.norm_added_q(query_c)
key_c = self.norm_added_k(key_c)
query = torch.cat([query_c, query], dim=2)
key = torch.cat([key_c, key], dim=2)
value = torch.cat([value_c, value], dim=2)
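        # Apply rotary embeddings with the rotate-half formulation:
        # q' = q * cos + rotate_half(q) * sin (and likewise for k), computed in ROPE_PRECISION.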
if image_rotary_emb is not None:
x_real, x_imag = query.reshape(*query.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
query = (query.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(query)
x_real, x_imag = key.reshape(*key.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
key = (key.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(key)
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = torch.split_with_sizes(
hidden_states,
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]],
dim=1,
)
hidden_states = self.to_out[0](hidden_states)
encoder_hidden_states = self.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
return hidden_states
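    # Fold the separate Q/K/V (and added Q/K/V) projections into single fused
    # linears so each attention call issues one GEMM per stream.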
@torch.no_grad()
def fuse_projections(self):
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
for layer in ("to_q", "to_k", "to_v", "to_added_q", "to_added_k", "to_added_v"):
if hasattr(self, layer):
module = getattr(self, layer)
module.to("meta")
delattr(self, layer)
self.fused_projections = True
class FluxPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
ids[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=ROPE_PRECISION,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, None]
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, None]
return freqs_cos, freqs_sin
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
attn_output = self.attn(hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states = self.proj_out(hidden_states)
hidden_states = torch.addcmul(residual, gate, hidden_states)
return hidden_states
class FluxTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = Attention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
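        # temb is this block's pre-computed AdaLN projection (see fuse_adaln_linear_);
        # the first half modulates the image stream, the second half the text stream.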
temb, temb_context = temb.chunk(2, dim=-1)
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb_context
)
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp)
norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp)
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output)
return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
)
for _ in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
)
for _ in range(num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
pooled_projections: torch.Tensor,
timestep: torch.LongTensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
guidance: Optional[torch.Tensor],
dt: torch.Tensor,
) -> torch.Tensor:
x_t = hidden_states
timestep = (timestep * 1000).to(hidden_states.dtype)
if guidance is not None:
guidance = (guidance * 1000).to(hidden_states.dtype)
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
temb = torch.nn.functional.silu(temb)
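        # Two fused GEMMs produce the AdaLN modulation for all dual- and
        # single-stream blocks; chunk each result per layer (weights were
        # concatenated by fuse_adaln_linear_).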
adaln_linear_states = self.adaln_linear(temb).unsqueeze(1).chunk(self.config.num_layers, dim=-1)
adaln_linear_states_single = self.adaln_linear_single(temb).unsqueeze(1).chunk(self.config.num_single_layers, dim=-1)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
for i, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=adaln_linear_states[i],
image_rotary_emb=image_rotary_emb,
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for i, block in enumerate(self.single_transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
temb=adaln_linear_states_single[i],
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
v = self.proj_out(hidden_states)
x = x_t + dt * v
return x
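# Replace each Attention module's separate Q/K/V projections with fused
# to_qkv / to_added_qkv linears (in place).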
@torch.no_grad()
def fuse_qkv_(model: FluxTransformer2DModel) -> None:
for submodule in model.modules():
if not isinstance(submodule, Attention):
continue
submodule.fuse_projections()
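# Concatenate the AdaLN modulation linears of all blocks into two large
# projections (adaln_linear for the dual-stream blocks, adaln_linear_single for
# the single-stream blocks) so the modulation for every layer is a single GEMM.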
@torch.no_grad()
def fuse_adaln_linear_(model: FluxTransformer2DModel) -> None:
adaln_linear_weights = []
adaln_linear_biases = []
for block in model.transformer_blocks:
adaln_linear_weights.append(block.norm1.linear.weight.data.clone())
adaln_linear_weights.append(block.norm1_context.linear.weight.data.clone())
adaln_linear_biases.append(block.norm1.linear.bias.data.clone())
adaln_linear_biases.append(block.norm1_context.linear.bias.data.clone())
block.norm1.linear.to("meta")
block.norm1_context.linear.to("meta")
del block.norm1.linear, block.norm1_context.linear
adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
in_features = adaln_linear_weights.shape[1]
out_features = adaln_linear_weights.shape[0]
model.adaln_linear = torch.nn.Linear(in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype)
model.adaln_linear.weight.copy_(adaln_linear_weights)
model.adaln_linear.bias.copy_(adaln_linear_biases)
adaln_linear_weights = []
adaln_linear_biases = []
for block in model.single_transformer_blocks:
adaln_linear_weights.append(block.norm.linear.weight.data.clone())
adaln_linear_biases.append(block.norm.linear.bias.data.clone())
block.norm.linear.to("meta")
del block.norm.linear
adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
in_features = adaln_linear_weights.shape[1]
out_features = adaln_linear_weights.shape[0]
model.adaln_linear_single = torch.nn.Linear(
in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
)
model.adaln_linear_single.weight.copy_(adaln_linear_weights)
model.adaln_linear_single.bias.copy_(adaln_linear_biases)
def prepare_clip_embeddings(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
prompt: str,
device: torch.device,
dtype: torch.dtype,
max_length: int = 77,
) -> torch.Tensor:
prompt = [prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
prompt_embeds = prompt_embeds.pooler_output.to(dtype)
return prompt_embeds
def prepare_t5_embeddings(
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
prompt: str,
device: torch.device,
dtype: torch.dtype,
max_length: int = 512,
) -> torch.Tensor:
prompt = [prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
prompt_embeds = prompt_embeds.to(dtype)
return prompt_embeds
def prepare_latent_image_ids(height: int, width: int, device: torch.device, dtype: torch.dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
).contiguous()
return latent_image_ids.to(device=device, dtype=dtype)
@torch.inference_mode()
def main(
model_id: str,
num_inference_steps: int = 28,
mode: str = "none",
cache_dir: Optional[str] = None,
enable_profiling: bool = False,
):
device = "cuda"
dtype = torch.bfloat16
BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15
M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN)
B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN
prompt = "A cat holding a sign that says 'Hello, World!'"
transformer = FluxTransformer2DModel.from_pretrained(
model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype
)
text_encoder = CLIPTextModel.from_pretrained(
model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype
)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=dtype
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype)
image_processor = VaeImageProcessor(vae_scale_factor=8 * 2)
fuse_qkv_(transformer)
fuse_adaln_linear_(transformer)
    for module in (transformer, text_encoder, text_encoder_2, vae):
        module.to(device)
if mode != "none":
transformer = torch.compile(transformer, mode=mode, fullgraph=True, dynamic=False)
batch_size = 1
height = 1024
width = 1024
spatial_compression_ratio = 8
patch_size = 1
in_channels = 16
t5_sequence_length = 512
latent_height = height // (spatial_compression_ratio * patch_size) // 2
latent_width = width // (spatial_compression_ratio * patch_size) // 2
guidance_scale = 4.0
latents = torch.randn((batch_size, latent_height * latent_width, in_channels * 2 * 2), dtype=dtype, device=device)
pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, device, dtype)
encoder_hidden_states = prepare_t5_embeddings(
text_encoder_2, tokenizer_2, prompt, device, dtype, t5_sequence_length
)
guidance = torch.full([batch_size], guidance_scale, device=device, dtype=torch.float32)
img_ids = prepare_latent_image_ids(latent_height, latent_width, device=device, dtype=dtype)
text_ids = torch.zeros(t5_sequence_length, 3).to(device=device, dtype=dtype)
ids = torch.cat([text_ids, img_ids], dim=0).float()
image_rotary_emb = transformer.pos_embed(ids)
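    # Flow-matching sigmas with the resolution-dependent time shift used by Flux:
    # mu is linear in the image sequence length and
    # sigma -> exp(mu) / (exp(mu) + (1 / sigma - 1)).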
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = B + image_seq_len * M
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
sigmas_dtype = sigmas.to(dtype=dtype)
dt = sigmas[1:] - sigmas[:-1]
print("Warmup step...")
for _ in range(2):
_ = transformer(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projections,
timestep=sigmas_dtype[0].expand(batch_size),
image_rotary_emb=image_rotary_emb,
guidance=guidance,
dt=dt[0],
)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
context = profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True) if enable_profiling else contextlib.nullcontext()
with context as ctx:
start_event.record()
for i in range(num_inference_steps):
latents = transformer(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projections,
timestep=sigmas_dtype[i].expand(batch_size),
image_rotary_emb=image_rotary_emb,
guidance=guidance,
dt=dt[i],
)
end_event.record()
torch.cuda.synchronize()
total_time = start_event.elapsed_time(end_event) / 1000.0
print(f"time: {total_time:.2f}s")
if enable_profiling:
print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
ctx.export_chrome_trace(f"dump_benchmark_flux_normal_{mode}.json")
latents = latents.reshape(-1, latent_height, latent_width, in_channels, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.flatten(4, 5).flatten(2, 3)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="pil")[0]
image.save("output.png")
@torch.inference_mode()
def main_cudagraph(
model_id: str,
num_inference_steps: int = 28,
mode: str = "none",
cache_dir: Optional[str] = None,
enable_profiling: bool = False,
):
device = "cuda"
dtype = torch.bfloat16
BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15
M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN)
B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN
prompt = "A cat holding a sign that says 'Hello, World!'"
transformer = FluxTransformer2DModel.from_pretrained(
model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype
)
text_encoder = CLIPTextModel.from_pretrained(
model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype
)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=dtype
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype)
image_processor = VaeImageProcessor(vae_scale_factor=8 * 2)
fuse_qkv_(transformer)
fuse_adaln_linear_(transformer)
    for module in (transformer, text_encoder, text_encoder_2, vae):
        module.to(device)
if mode != "none":
transformer = torch.compile(transformer, mode=mode, fullgraph=True, dynamic=False)
batch_size = 1
height = 1024
width = 1024
spatial_compression_ratio = 8
patch_size = 1
in_channels = 16
t5_sequence_length = 512
latent_height = height // (spatial_compression_ratio * patch_size) // 2
latent_width = width // (spatial_compression_ratio * patch_size) // 2
guidance_scale = 4.0
latents = torch.randn((batch_size, latent_height * latent_width, in_channels * 2 * 2), dtype=dtype, device=device)
pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, device, dtype)
encoder_hidden_states = prepare_t5_embeddings(
text_encoder_2, tokenizer_2, prompt, device, dtype, t5_sequence_length
)
guidance = torch.full([batch_size], guidance_scale, device=device, dtype=torch.float32)
img_ids = prepare_latent_image_ids(latent_height, latent_width, device=device, dtype=dtype)
text_ids = torch.zeros(t5_sequence_length, 3).to(device=device, dtype=dtype)
    ids = torch.cat([text_ids, img_ids], dim=0).float()
image_rotary_emb = transformer.pos_embed(ids)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = B + image_seq_len * M
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
sigmas_dtype = sigmas.to(dtype=dtype)
dt = sigmas[1:] - sigmas[:-1]
print("Warmup step...")
for _ in range(2):
_ = transformer(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projections,
timestep=sigmas_dtype[0].expand(batch_size),
image_rotary_emb=image_rotary_emb,
guidance=guidance,
dt=dt[0],
)
torch.cuda.synchronize()
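    # Capture a single denoising step into a CUDA graph with static buffers;
    # each step afterwards only copies fresh inputs into those buffers and replays.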
print("Capturing CUDA graph...")
graph = torch.cuda.CUDAGraph()
static_latents = latents.clone()
static_timestep = sigmas_dtype[1].expand(batch_size).clone()
static_dt = dt[0].clone()
with torch.cuda.graph(graph):
static_v = transformer(
hidden_states=static_latents,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projections,
timestep=static_timestep,
image_rotary_emb=image_rotary_emb,
guidance=guidance,
dt=static_dt,
)
torch.cuda.synchronize()
static_v.copy_(latents)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
context = profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True) if enable_profiling else contextlib.nullcontext()
with context as ctx:
start_event.record()
for i in range(num_inference_steps):
static_latents.copy_(static_v)
static_timestep.copy_(sigmas_dtype[i].expand(batch_size))
static_dt.copy_(dt[i])
graph.replay()
end_event.record()
torch.cuda.synchronize()
latents = static_v
total_time = start_event.elapsed_time(end_event) / 1000.0
print(f"time: {total_time:.2f}s")
if enable_profiling:
print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
ctx.export_chrome_trace(f"dump_benchmark_flux_cudagraph_{mode}.json")
latents = latents.reshape(-1, latent_height, latent_width, in_channels, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.flatten(4, 5).flatten(2, 3)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="pil")[0]
image.save("output.png")
if __name__ == "__main__":
model_id = "black-forest-labs/FLUX.1-dev"
cache_dir = "/raid/.cache/huggingface"
parser = argparse.ArgumentParser()
parser.add_argument("--enable_cudagraph", action="store_true")
parser.add_argument("--compile_mode", type=str, default="none", choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"])
parser.add_argument("--num_inference_steps", type=int, default=28)
parser.add_argument("--reduce_rope_precision", action="store_true")
parser.add_argument("--enable_profiling", action="store_true")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]:
raise ValueError("CUDAGraph requires compile_mode to be 'default' or 'max-autotune-no-cudagraphs'.")
ROPE_PRECISION = torch.bfloat16 if args.reduce_rope_precision else torch.float32
torch.manual_seed(args.seed)
if not args.enable_cudagraph:
main(
model_id=model_id,
num_inference_steps=args.num_inference_steps,
mode=args.compile_mode,
cache_dir=cache_dir,
enable_profiling=args.enable_profiling,
)
else:
main_cudagraph(
model_id=model_id,
num_inference_steps=args.num_inference_steps,
mode=args.compile_mode,
cache_dir=cache_dir,
enable_profiling=args.enable_profiling,
)
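# Example invocations (the script filename is hypothetical):
#   python flux_benchmark.py --compile_mode max-autotune-no-cudagraphs --enable_cudagraph
#   python flux_benchmark.py --compile_mode max-autotune --num_inference_steps 28 --enable_profiling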