import argparse
import contextlib
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch._inductor.config
import torch._higher_order_ops.auto_functionalize as af
import triton
import triton.language as tl
from torch.profiler import profile, record_function, ProfilerActivity

from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
)
from diffusers.models.modeling_utils import ModelMixin
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.epilogue_fusion = False
torch._inductor.config.fx_graph_cache = True
# torch._inductor.config.max_fusion_size = 128
# torch._inductor.config.max_pointwise_cat_inputs = 16
# torch._inductor.config.force_pointwise_cat = True
torch._inductor.config.aggressive_fusion = True
# torch._inductor.config.triton.unique_kernel_names = True

af.auto_functionalized_v2._cacheable = True
af.auto_functionalized._cacheable = True
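
# Precision used for the rotary position embedding math. Overridden to bfloat16 in
# __main__ when --reduce_rope_precision is passed; float32 is the accuracy-preserving default.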
ROPE_PRECISION = torch.float32


class AdaLayerNormContinuous(nn.Module):
    def __init__(
        self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True
    ):
        super().__init__()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        emb = self.linear(emb)
        scale, shift = emb.unsqueeze(1).chunk(2, dim=-1)
        norm_x = self.norm(x)
        x = torch.addcmul(shift, norm_x, 1 + scale)
        return x
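

# The AdaLayerNormZero* variants below differ from the stock diffusers modules: each block's
# `linear` projection of the conditioning embedding is later concatenated into one large linear
# (see fuse_adaln_linear_), so forward() receives the already-projected, per-block slice of that
# output as `emb` and only has to chunk it.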
class AdaLayerNormZeroSingle(nn.Module):
    def __init__(self, embedding_dim: int, bias=True):
        super().__init__()
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None):
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
        norm_x = self.norm(x)
        x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa)
        return x, gate_msa


class AdaLayerNormZero(nn.Module):
    def __init__(self, embedding_dim: int, bias=True):
        super().__init__()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        norm_x = self.norm(x)
        x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
    ):
        super().__init__()
        assert qk_norm == "rms_norm", "Flux uses RMSNorm"

        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_proj_bias = added_proj_bias

        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=out_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        cos, sin = image_rotary_emb if image_rotary_emb is not None else (None, None)

        query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1)
        query, key, value = (x.unflatten(2, (self.heads, -1)).transpose(1, 2) for x in (query, key, value))
        query = self.norm_q(query)
        key = self.norm_k(key)

        if encoder_hidden_states is not None:
            query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
            query_c, key_c, value_c = (
                x.unflatten(2, (self.heads, -1)).transpose(1, 2) for x in (query_c, key_c, value_c)
            )
            query_c = self.norm_added_q(query_c)
            key_c = self.norm_added_k(key_c)
            query = torch.cat([query_c, query], dim=2)
            key = torch.cat([key_c, key], dim=2)
            value = torch.cat([value_c, value], dim=2)
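
        # Rotary embedding applied in interleaved (real, imag) form: for each pair (x0, x1) the
        # rotated result is (x0 * cos - x1 * sin, x1 * cos + x0 * sin), computed as
        # x * cos + rotate_half(x) * sin with the multiplications done in ROPE_PRECISION.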
        if image_rotary_emb is not None:
            x_real, x_imag = query.reshape(*query.shape[:-1], -1, 2).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
            query = (query.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(query)

            x_real, x_imag = key.reshape(*key.shape[:-1], -1, 2).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
            key = (key.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(key)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = torch.split_with_sizes(
                hidden_states,
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]],
                dim=1,
            )
            hidden_states = self.to_out[0](hidden_states)
            encoder_hidden_states = self.to_add_out(encoder_hidden_states)
            return hidden_states, encoder_hidden_states

        return hidden_states
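
    # fuse_projections concatenates the separate to_q/to_k/to_v (and add_q_proj/add_k_proj/add_v_proj)
    # weights into single to_qkv / to_added_qkv linears so each projection runs as one GEMM; the
    # original to_q/to_k/to_v modules are moved to the meta device and deleted once copied.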
    @torch.no_grad()
    def fuse_projections(self):
        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
        in_features = concatenated_weights.shape[1]
        out_features = concatenated_weights.shape[0]
        self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
        self.to_qkv.weight.copy_(concatenated_weights)
        if self.use_bias:
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            self.to_qkv.bias.copy_(concatenated_bias)

        if (
            getattr(self, "add_q_proj", None) is not None
            and getattr(self, "add_k_proj", None) is not None
            and getattr(self, "add_v_proj", None) is not None
        ):
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]
            self.to_added_qkv = nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            self.to_added_qkv.weight.copy_(concatenated_weights)
            if self.added_proj_bias:
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                self.to_added_qkv.bias.copy_(concatenated_bias)

        for layer in ("to_q", "to_k", "to_v", "to_added_q", "to_added_k", "to_added_v"):
            if hasattr(self, layer):
                module = getattr(self, layer)
                module.to("meta")
                delattr(self, layer)

        self.fused_projections = True


class FluxPosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                ids[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=ROPE_PRECISION,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, None]
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, None]
        return freqs_cos, freqs_sin


class FluxSingleTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )
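
    # `temb` here is the per-block slice of the fused AdaLN-single projection computed once per
    # step in FluxTransformer2DModel.forward (see fuse_adaln_linear_), not the raw conditioning.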
    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        attn_output = self.attn(hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb)
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = torch.addcmul(residual, gate, hidden_states)
        return hidden_states


class FluxTransformerBlock(nn.Module):
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()
        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)
        self.attn = Attention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            qk_norm=qk_norm,
            eps=eps,
        )
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
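        # The fused AdaLN slice packs the image-stream (norm1) conditioning in its first half and
        # the text-stream (norm1_context) conditioning in its second half.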
        temb, temb_context = temb.chunk(2, dim=-1)
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb_context
        )
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output)
        encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output)
        norm_hidden_states = self.norm2(hidden_states)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp)
        norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp)
        ff_output = self.ff(norm_hidden_states)
        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output)
        encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output)
        return encoder_hidden_states, hidden_states


class FluxTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
                )
                for _ in range(num_layers)
            ]
        )
        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
                )
                for _ in range(num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
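
    # forward() assumes fuse_adaln_linear_ has been applied (it uses self.adaln_linear and
    # self.adaln_linear_single) and folds the Euler update x = x_t + dt * v into the model call,
    # so one invocation performs a full denoising step.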
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.LongTensor,
        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        guidance: Optional[torch.Tensor],
        dt: torch.Tensor,
    ) -> torch.Tensor:
        x_t = hidden_states
        timestep = (timestep * 1000).to(hidden_states.dtype)
        if guidance is not None:
            guidance = (guidance * 1000).to(hidden_states.dtype)
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        temb = torch.nn.functional.silu(temb)
        adaln_linear_states = self.adaln_linear(temb).unsqueeze(1).chunk(self.config.num_layers, dim=-1)
        adaln_linear_states_single = self.adaln_linear_single(temb).unsqueeze(1).chunk(
            self.config.num_single_layers, dim=-1
        )

        hidden_states = self.x_embedder(hidden_states)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        for i, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=adaln_linear_states[i],
                image_rotary_emb=image_rotary_emb,
            )

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        for i, block in enumerate(self.single_transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                temb=adaln_linear_states_single[i],
                image_rotary_emb=image_rotary_emb,
            )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
        hidden_states = self.norm_out(hidden_states, temb)
        v = self.proj_out(hidden_states)
        x = x_t + dt * v
        return x


@torch.no_grad()
def fuse_qkv_(model: FluxTransformer2DModel) -> None:
    for submodule in model.modules():
        if not isinstance(submodule, Attention):
            continue
        submodule.fuse_projections()
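

# fuse_adaln_linear_ concatenates every block's AdaLN linear (norm1 / norm1_context in the
# dual-stream blocks, norm in the single-stream blocks) into two large linears on the model.
# The per-step conditioning is then projected once and chunked per block in forward(), instead
# of running one small GEMM per block per step. The original per-block linears are moved to the
# meta device and deleted.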
@torch.no_grad()
def fuse_adaln_linear_(model: FluxTransformer2DModel) -> None:
    adaln_linear_weights = []
    adaln_linear_biases = []
    for block in model.transformer_blocks:
        adaln_linear_weights.append(block.norm1.linear.weight.data.clone())
        adaln_linear_weights.append(block.norm1_context.linear.weight.data.clone())
        adaln_linear_biases.append(block.norm1.linear.bias.data.clone())
        adaln_linear_biases.append(block.norm1_context.linear.bias.data.clone())
        block.norm1.linear.to("meta")
        block.norm1_context.linear.to("meta")
        del block.norm1.linear, block.norm1_context.linear
    adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
    adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
    in_features = adaln_linear_weights.shape[1]
    out_features = adaln_linear_weights.shape[0]
    model.adaln_linear = torch.nn.Linear(
        in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
    )
    model.adaln_linear.weight.copy_(adaln_linear_weights)
    model.adaln_linear.bias.copy_(adaln_linear_biases)

    adaln_linear_weights = []
    adaln_linear_biases = []
    for block in model.single_transformer_blocks:
        adaln_linear_weights.append(block.norm.linear.weight.data.clone())
        adaln_linear_biases.append(block.norm.linear.bias.data.clone())
        block.norm.linear.to("meta")
        del block.norm.linear
    adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
    adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
    in_features = adaln_linear_weights.shape[1]
    out_features = adaln_linear_weights.shape[0]
    model.adaln_linear_single = torch.nn.Linear(
        in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
    )
    model.adaln_linear_single.weight.copy_(adaln_linear_weights)
    model.adaln_linear_single.bias.copy_(adaln_linear_biases)


def prepare_clip_embeddings(
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    max_length: int = 77,
) -> torch.Tensor:
    prompt = [prompt]
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
    prompt_embeds = prompt_embeds.pooler_output.to(dtype)
    return prompt_embeds


def prepare_t5_embeddings(
    text_encoder: T5EncoderModel,
    tokenizer: T5TokenizerFast,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    max_length: int = 512,
) -> torch.Tensor:
    prompt = [prompt]
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype)
    return prompt_embeds


def prepare_latent_image_ids(height: int, width: int, device: torch.device, dtype: torch.dtype):
    latent_image_ids = torch.zeros(height, width, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    ).contiguous()
    return latent_image_ids.to(device=device, dtype=dtype)


@torch.inference_mode()
def main(
    model_id: str,
    num_inference_steps: int = 28,
    mode: str = "none",
    cache_dir: Optional[str] = None,
    enable_profiling: bool = False,
):
    device = "cuda"
    dtype = torch.bfloat16

    BASE_IMAGE_SEQ_LEN = 256
    MAX_IMAGE_SEQ_LEN = 4096
    BASE_SHIFT = 0.5
    MAX_SHIFT = 1.15
    M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN)
    B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN

    prompt = "A cat holding a sign that says 'Hello, World!'"

    transformer = FluxTransformer2DModel.from_pretrained(
        model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype
    )
    text_encoder = CLIPTextModel.from_pretrained(
        model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=dtype
    )
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
    tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype)
    image_processor = VaeImageProcessor(vae_scale_factor=8 * 2)

    fuse_qkv_(transformer)
    fuse_adaln_linear_(transformer)
    [x.to(device) for x in (transformer, text_encoder, text_encoder_2, vae)]

    if mode != "none":
        transformer = torch.compile(transformer, mode=mode, fullgraph=True, dynamic=False)

    batch_size = 1
    height = 1024
    width = 1024
    spatial_compression_ratio = 8
    patch_size = 1
    in_channels = 16
    t5_sequence_length = 512
    latent_height = height // (spatial_compression_ratio * patch_size) // 2
    latent_width = width // (spatial_compression_ratio * patch_size) // 2
    guidance_scale = 4.0

    latents = torch.randn((batch_size, latent_height * latent_width, in_channels * 2 * 2), dtype=dtype, device=device)
    pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, device, dtype)
    encoder_hidden_states = prepare_t5_embeddings(
        text_encoder_2, tokenizer_2, prompt, device, dtype, t5_sequence_length
    )
    guidance = torch.full([batch_size], guidance_scale, device=device, dtype=torch.float32)
    img_ids = prepare_latent_image_ids(latent_height, latent_width, device=device, dtype=dtype)
    text_ids = torch.zeros(t5_sequence_length, 3).to(device=device, dtype=dtype)
    ids = torch.cat([text_ids, img_ids], dim=0).float()
    image_rotary_emb = transformer.pos_embed(ids)
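
    # Resolution-dependent timestep shift: mu is linear in the image sequence length
    # (mu = B + M * image_seq_len) and each sigma is remapped as exp(mu) / (exp(mu) + (1 / sigma - 1)).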
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = B + image_seq_len * M
    sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0)
    sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
    sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
    sigmas_dtype = sigmas.to(dtype=dtype)
    dt = sigmas[1:] - sigmas[:-1]

    print("Warmup step...")
    for _ in range(2):
        _ = transformer(
            hidden_states=latents,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=sigmas_dtype[0].expand(batch_size),
            image_rotary_emb=image_rotary_emb,
            guidance=guidance,
            dt=dt[0],
        )
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    context = (
        profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True)
        if enable_profiling
        else contextlib.nullcontext()
    )
    with context as ctx:
        start_event.record()
        for i in range(num_inference_steps):
            latents = transformer(
                hidden_states=latents,
                encoder_hidden_states=encoder_hidden_states,
                pooled_projections=pooled_projections,
                timestep=sigmas_dtype[i].expand(batch_size),
                image_rotary_emb=image_rotary_emb,
                guidance=guidance,
                dt=dt[i],
            )
        end_event.record()
    torch.cuda.synchronize()

    total_time = start_event.elapsed_time(end_event) / 1000.0
    print(f"time: {total_time:.2f}s")

    if enable_profiling:
        print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
        ctx.export_chrome_trace(f"dump_benchmark_flux_normal_{mode}.json")
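
    # Un-patchify the latent sequence back to a (B, C, H, W) image latent, undo the VAE
    # scaling/shift, and decode to pixels.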
    latents = latents.reshape(-1, latent_height, latent_width, in_channels, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    latents = latents.flatten(4, 5).flatten(2, 3)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")[0]
    image.save("output.png")


@torch.inference_mode()
def main_cudagraph(
    model_id: str,
    num_inference_steps: int = 28,
    mode: str = "none",
    cache_dir: Optional[str] = None,
    enable_profiling: bool = False,
):
    device = "cuda"
    dtype = torch.bfloat16

    BASE_IMAGE_SEQ_LEN = 256
    MAX_IMAGE_SEQ_LEN = 4096
    BASE_SHIFT = 0.5
    MAX_SHIFT = 1.15
    M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN)
    B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN

    prompt = "A cat holding a sign that says 'Hello, World!'"

    transformer = FluxTransformer2DModel.from_pretrained(
        model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype
    )
    text_encoder = CLIPTextModel.from_pretrained(
        model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=dtype
    )
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
    tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype)
    image_processor = VaeImageProcessor(vae_scale_factor=8 * 2)

    fuse_qkv_(transformer)
    fuse_adaln_linear_(transformer)
    [x.to(device) for x in (transformer, text_encoder, text_encoder_2, vae)]

    if mode != "none":
        transformer = torch.compile(transformer, mode=mode, fullgraph=True, dynamic=False)

    batch_size = 1
    height = 1024
    width = 1024
    spatial_compression_ratio = 8
    patch_size = 1
    in_channels = 16
    t5_sequence_length = 512
    latent_height = height // (spatial_compression_ratio * patch_size) // 2
    latent_width = width // (spatial_compression_ratio * patch_size) // 2
    guidance_scale = 4.0

    latents = torch.randn((batch_size, latent_height * latent_width, in_channels * 2 * 2), dtype=dtype, device=device)
    pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, device, dtype)
    encoder_hidden_states = prepare_t5_embeddings(
        text_encoder_2, tokenizer_2, prompt, device, dtype, t5_sequence_length
    )
    guidance = torch.full([batch_size], guidance_scale, device=device, dtype=torch.float32)
    img_ids = prepare_latent_image_ids(latent_height, latent_width, device=device, dtype=dtype)
    text_ids = torch.zeros(t5_sequence_length, 3).to(device=device, dtype=dtype)
    ids = torch.cat([text_ids, img_ids], dim=0)
    image_rotary_emb = transformer.pos_embed(ids)

    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = B + image_seq_len * M
    sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0)
    sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
    sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
    sigmas_dtype = sigmas.to(dtype=dtype)
    dt = sigmas[1:] - sigmas[:-1]

    print("Warmup step...")
    for _ in range(2):
        _ = transformer(
            hidden_states=latents,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=sigmas_dtype[0].expand(batch_size),
            image_rotary_emb=image_rotary_emb,
            guidance=guidance,
            dt=dt[0],
        )
    torch.cuda.synchronize()
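
    # Capture a single denoising step into a CUDA graph. Replays run against fixed memory, so the
    # per-step inputs (latents, timestep, dt) are written into static tensors with copy_() before
    # each replay. The output buffer static_v doubles as the carry between steps: it is seeded with
    # the initial latents, copied into static_latents before every replay, and holds the final
    # latents once the loop finishes.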
print("Capturing CUDA graph...") | |
graph = torch.cuda.CUDAGraph() | |
static_latents = latents.clone() | |
static_timestep = sigmas_dtype[1].expand(batch_size).clone() | |
static_dt = dt[0].clone() | |
with torch.cuda.graph(graph): | |
static_v = transformer( | |
hidden_states=static_latents, | |
encoder_hidden_states=encoder_hidden_states, | |
pooled_projections=pooled_projections, | |
timestep=static_timestep, | |
image_rotary_emb=image_rotary_emb, | |
guidance=guidance, | |
dt=static_dt, | |
) | |
torch.cuda.synchronize() | |
static_v.copy_(latents) | |
start_event = torch.cuda.Event(enable_timing=True) | |
end_event = torch.cuda.Event(enable_timing=True) | |
context = profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True) if enable_profiling else contextlib.nullcontext() | |
with context as ctx: | |
start_event.record() | |
for i in range(num_inference_steps): | |
static_latents.copy_(static_v) | |
static_timestep.copy_(sigmas_dtype[i].expand(batch_size)) | |
static_dt.copy_(dt[i]) | |
graph.replay() | |
end_event.record() | |
torch.cuda.synchronize() | |
latents = static_v | |
total_time = start_event.elapsed_time(end_event) / 1000.0 | |
print(f"time: {total_time:.2f}s") | |
if enable_profiling: | |
print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) | |
ctx.export_chrome_trace(f"dump_benchmark_flux_cudagraph_{mode}.json") | |
latents = latents.reshape(-1, latent_height, latent_width, in_channels, 2, 2) | |
latents = latents.permute(0, 3, 1, 4, 2, 5) | |
latents = latents.flatten(4, 5).flatten(2, 3) | |
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor | |
image = vae.decode(latents, return_dict=False)[0] | |
image = image_processor.postprocess(image, output_type="pil")[0] | |
image.save("output.png") | |


if __name__ == "__main__":
    model_id = "black-forest-labs/FLUX.1-dev"
    cache_dir = "/raid/.cache/huggingface"

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable_cudagraph", action="store_true")
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="none",
        choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
    )
    parser.add_argument("--num_inference_steps", type=int, default=28)
    parser.add_argument("--reduce_rope_precision", action="store_true")
    parser.add_argument("--enable_profiling", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]:
        raise ValueError("CUDAGraph requires compile_mode to be 'none', 'default', or 'max-autotune-no-cudagraphs'.")

    ROPE_PRECISION = torch.bfloat16 if args.reduce_rope_precision else torch.float32
    torch.manual_seed(args.seed)

    if not args.enable_cudagraph:
        main(
            model_id=model_id,
            num_inference_steps=args.num_inference_steps,
            mode=args.compile_mode,
            cache_dir=cache_dir,
            enable_profiling=args.enable_profiling,
        )
    else:
        main_cudagraph(
            model_id=model_id,
            num_inference_steps=args.num_inference_steps,
            mode=args.compile_mode,
            cache_dir=cache_dir,
            enable_profiling=args.enable_profiling,
        )
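

# Example invocations (assuming this file is saved as, e.g., benchmark_flux.py; adjust
# model_id / cache_dir above for your environment):
#   python benchmark_flux.py --compile_mode max-autotune-no-cudagraphs
#   python benchmark_flux.py --enable_cudagraph --compile_mode none --enable_profiling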