ring attention when you forget to do the rotations
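# FLUX.1-dev inference/benchmark script: a stripped-down FluxTransformer2DModel with fused QKV and AdaLN
# projections, precomputed timestep/guidance/text embeddings, optional torch.compile and CUDA graph capture,
# and optional context parallelism. With --cp_degree > 1 the sequence is sharded across ranks but K/V blocks
# are never exchanged between ranks, i.e. "ring attention when you forget to do the rotations".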
import argparse
import contextlib
import math
import pathlib
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.profiler._utils
import torch._dynamo.config
import torch._inductor.config
import torch._higher_order_ops.auto_functionalize as af
import triton
import triton.language as tl
from torch.distributed.tensor import DTensor, Shard
from torch.profiler import profile, record_function, ProfilerActivity
try:
    from flash_attn import flash_attn_func
except ImportError:
    print("Flash Attention 2 not found.")
try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    print("Flash Attention 3 not found.")
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import (
get_1d_rotary_pos_embed,
TimestepEmbedding,
Timesteps,
PixArtAlphaTextProjection,
)
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.modeling_utils import ModelMixin
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
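# Dynamo/Inductor tuning knobs: coordinate-descent tuning with a small search radius, aggressive fusion,
# FX graph caching, and auto_functionalized ops marked cacheable; nn.Module inlining and epilogue fusion
# are disabled.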
torch._dynamo.config.inline_inbuilt_nn_modules = False
torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.coordinate_descent_search_radius = 1
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.fx_graph_cache = True
# torch._inductor.config.max_fusion_size = 128
# torch._inductor.config.max_pointwise_cat_inputs = 16
# torch._inductor.config.force_pointwise_cat = True
torch._inductor.config.aggressive_fusion = True
# torch._inductor.config.triton.unique_kernel_names = True
af.auto_functionalized_v2._cacheable = True
af.auto_functionalized._cacheable = True
ROPE_PRECISION = torch.float32
SUPPORTED_GUIDANCE_SCALES = [i / 2 for i in range(41)] # 0, 0.5, 1.0, ..., 20.0
CP_MESH = None
# Constants
DEVICE = "cuda"
DTYPE = torch.bfloat16
MIN_STEPS = 2
MAX_STEPS = 50
SPATIAL_COMPRESSION_RATIO = 8
PATCH_SIZE = 1
IN_CHANNELS = 16
BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15
M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN)
B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN
# The following parameters are fixed for now
BATCH_SIZE = 1
HEIGHT = 1024
WIDTH = 1024
LATENT_HEIGHT = HEIGHT // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2
LATENT_WIDTH = WIDTH // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2
T5_SEQUENCE_LENGTH = 512
GUIDANCE_SCALE = 4.0
def _attention_torch_cudnn(query, key, value):
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
out = out.transpose(1, 2)
return out
def _attention_flash_attn_2(query, key, value):
return flash_attn_func(query, key, value)
# For fullgraph=True tracing to be compatible
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
out, lse = flash_attn_3_func(query, key, value)
return out
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
return torch.empty_like(query)
def _attention_flash_attn_3(query, key, value):
out = _wrapped_flash_attn_3(query, key, value)
return out
ATTENTION_OP = _attention_torch_cudnn
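# Shards tensors along the sequence dimension for context parallelism: `shard` keeps only this rank's
# contiguous chunk, `unshard` all-gathers the chunks back into the full tensor via a DTensor round-trip.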
class EquipartitionSharder:
@classmethod
@torch.compiler.disable
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
assert tensor.size()[dim] % mesh.size() == 0
return tensor.chunk(mesh.size(), dim=dim)[mesh.get_local_rank()]
@classmethod
@torch.compiler.disable
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
result = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
return result
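# Computes silu(timestep_emb + guidance_emb + pooled_text_emb); torch.compile lets the adds and SiLU fuse
# into a single pointwise kernel.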
@torch.compile
def pointwise_add3_silu(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(x + y + z)
class AdaLayerNormContinuous(nn.Module):
def __init__(
self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True
):
super().__init__()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
self.norm = torch.nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
emb = self.linear(emb)
scale, shift = emb.unsqueeze(1).chunk(2, dim=-1)
norm_x = self.norm(x)
x = torch.addcmul(shift, norm_x, 1 + scale)
return x
class AdaLayerNormZeroSingle(nn.Module):
def __init__(self, embedding_dim: int, bias=True):
super().__init__()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None):
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
norm_x = self.norm(x)
x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa)
return x, gate_msa
class AdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, bias=True):
super().__init__()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
def forward(self, x: torch.Tensor, emb: torch.Tensor):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
norm_x = self.norm(x)
x = torch.addcmul(shift_msa, norm_x, 1 + scale_msa)
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, t_emb, guidance_emb, pooled_projection_emb):
# We precompute timestep_emb, guidance_emb and pooled_projection_emb only once outside the forward
raise NotImplementedError("The implementation precomputes required embeddings, so this should be unreachable.")
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
):
super().__init__()
assert qk_norm == "rms_norm", "Flux uses RMSNorm"
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=out_bias)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
cos, sin = image_rotary_emb if image_rotary_emb is not None else (None, None)
query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1)
query, key, value = (x.unflatten(2, (self.heads, -1)) for x in (query, key, value))
query = self.norm_q(query)
key = self.norm_k(key)
if encoder_hidden_states is not None:
query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
query_c, key_c, value_c = (x.unflatten(2, (self.heads, -1)) for x in (query_c, key_c, value_c))
query_c = self.norm_added_q(query_c)
key_c = self.norm_added_k(key_c)
query = torch.cat([query_c, query], dim=1)
key = torch.cat([key_c, key], dim=1)
value = torch.cat([value_c, value], dim=1)
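# Apply rotary embeddings to query/key in ROPE_PRECISION using the interleaved (real, imag) rotate-half convention.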
if image_rotary_emb is not None:
x_real, x_imag = query.unflatten(-1, (-1, 2)).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
query = (query.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(query)
x_real, x_imag = key.unflatten(-1, (-1, 2)).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
key = (key.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(key)
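# NOTE: under context parallelism, query/key/value here only cover the local sequence shard and no ring
# exchange of K/V blocks is performed, so each rank attends to just its own slice of the sequence.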
hidden_states = ATTENTION_OP(query, key, value)
hidden_states = hidden_states.flatten(2, 3)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = torch.split_with_sizes(
hidden_states,
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]],
dim=1,
)
hidden_states = self.to_out[0](hidden_states)
encoder_hidden_states = self.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
return hidden_states
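# Fuse the separate q/k/v (and added q/k/v) projections into single to_qkv/to_added_qkv linears so the
# projections run as one GEMM; the original layers are moved to "meta" and deleted.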
@torch.no_grad()
def fuse_projections(self):
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
for layer in ("to_q", "to_k", "to_v", "to_added_q", "to_added_k", "to_added_v"):
if hasattr(self, layer):
module = getattr(self, layer)
module.to("meta")
delattr(self, layer)
self.fused_projections = True
class FluxPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
ids[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=ROPE_PRECISION,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None]
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None]
return freqs_cos, freqs_sin
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
attn_output = self.attn(hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states = self.proj_out(hidden_states)
hidden_states = torch.addcmul(residual, gate, hidden_states)
return hidden_states
class FluxTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = Attention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb, temb_context = temb.chunk(2, dim=-1)
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb_context
)
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp)
norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp)
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output)
return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
if not guidance_embeds:
raise ValueError("FLUX.1-schnell inference is not yet supported.")
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
)
for _ in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
)
for _ in range(num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
conditioning: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
dt: torch.Tensor,
) -> torch.Tensor:
x_t = hidden_states
if CP_MESH is not None:
hidden_states = EquipartitionSharder.shard(hidden_states, 1, CP_MESH)
encoder_hidden_states = EquipartitionSharder.shard(encoder_hidden_states, 1, CP_MESH)
image_rotary_emb = (
EquipartitionSharder.shard(image_rotary_emb[0], 1, CP_MESH),
EquipartitionSharder.shard(image_rotary_emb[1], 1, CP_MESH),
)
hidden_states = self.x_embedder(hidden_states)
adaln_linear_states = self.adaln_linear(conditioning).unsqueeze(1).chunk(self.config.num_layers, dim=-1)
adaln_linear_single_states = (
self.adaln_linear_single(conditioning).unsqueeze(1).chunk(self.config.num_single_layers, dim=-1)
)
for i, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=adaln_linear_states[i],
image_rotary_emb=image_rotary_emb,
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for i, block in enumerate(self.single_transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
temb=adaln_linear_single_states[i],
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, conditioning)
v = self.proj_out(hidden_states)
if CP_MESH is not None:
v = EquipartitionSharder.unshard(v, 1, CP_MESH)
x = x_t + dt * v
return x
@torch.no_grad()
def fuse_qkv_(model: FluxTransformer2DModel) -> None:
for submodule in model.modules():
if not isinstance(submodule, Attention):
continue
submodule.fuse_projections()
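# Concatenate every block's AdaLN modulation linear into one large `adaln_linear` (and `adaln_linear_single`
# for the single blocks) so all per-block shift/scale/gate projections are computed with a single GEMM and
# then chunked per layer inside the transformer forward.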
@torch.no_grad()
def fuse_adaln_linear_(model: FluxTransformer2DModel) -> None:
adaln_linear_weights = []
adaln_linear_biases = []
for block in model.transformer_blocks:
adaln_linear_weights.append(block.norm1.linear.weight.data.clone())
adaln_linear_weights.append(block.norm1_context.linear.weight.data.clone())
adaln_linear_biases.append(block.norm1.linear.bias.data.clone())
adaln_linear_biases.append(block.norm1_context.linear.bias.data.clone())
block.norm1.linear.to("meta")
block.norm1_context.linear.to("meta")
del block.norm1.linear, block.norm1_context.linear
adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
in_features = adaln_linear_weights.shape[1]
out_features = adaln_linear_weights.shape[0]
model.adaln_linear = torch.nn.Linear(
in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
)
model.adaln_linear.weight.copy_(adaln_linear_weights)
model.adaln_linear.bias.copy_(adaln_linear_biases)
adaln_linear_weights = []
adaln_linear_biases = []
for block in model.single_transformer_blocks:
adaln_linear_weights.append(block.norm.linear.weight.data.clone())
adaln_linear_biases.append(block.norm.linear.bias.data.clone())
block.norm.linear.to("meta")
del block.norm.linear
adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
in_features = adaln_linear_weights.shape[1]
out_features = adaln_linear_weights.shape[0]
model.adaln_linear_single = torch.nn.Linear(
in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
)
model.adaln_linear_single.weight.copy_(adaln_linear_weights)
model.adaln_linear_single.bias.copy_(adaln_linear_biases)
def prepare_clip_embeddings(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
prompt: str,
device: torch.device,
dtype: torch.dtype,
max_length: int = 77,
) -> torch.Tensor:
prompt = [prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
prompt_embeds = prompt_embeds.pooler_output.to(dtype)
return prompt_embeds
def prepare_t5_embeddings(
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
prompt: str,
device: torch.device,
dtype: torch.dtype,
max_length: int = 512,
) -> torch.Tensor:
prompt = [prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
prompt_embeds = prompt_embeds.to(dtype)
return prompt_embeds
def prepare_latent_image_ids(height: int, width: int, device: torch.device, dtype: torch.dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
).contiguous()
return latent_image_ids.to(device=device, dtype=dtype)
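# Precompute guidance embeddings for every supported guidance scale so the guidance embedder never runs
# during denoising.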
def precompute_guidance_embeds(transformer: FluxTransformer2DModel):
embeds = {}
for guidance_scale in SUPPORTED_GUIDANCE_SCALES:
guidance = torch.full([1], guidance_scale, device="cuda", dtype=torch.float32)
guidance = transformer.time_text_embed.guidance_embedder(
transformer.time_text_embed.time_proj(guidance * 1000.0).to(dtype=torch.bfloat16)
)
embeds[f"{guidance_scale:.1f}"] = guidance
return embeds
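# Precompute the sigma schedule (with Flux's resolution-dependent shift mu) and the timestep embeddings for
# every supported number of inference steps.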
def precompute_timestep_embeds(transformer: FluxTransformer2DModel):
embeds = {}
image_seq_len = LATENT_HEIGHT * LATENT_WIDTH
for num_inference_steps in range(MIN_STEPS, MAX_STEPS + 1):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = B + image_seq_len * M
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0)
sigmas = torch.from_numpy(sigmas).to(DEVICE, dtype=torch.float32)
sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
timesteps = (sigmas * 1000.0).to(DTYPE)
temb = transformer.time_text_embed.time_proj(timesteps)
temb = transformer.time_text_embed.timestep_embedder(temb.to(DTYPE))
embeds[num_inference_steps] = (sigmas, temb)
return embeds
def precompute_embeds(transformer: FluxTransformer2DModel, save_dir: str):
save_dir = pathlib.Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
guidance_path = save_dir / "guidance_embeds.pt"
timestep_path = save_dir / "timestep_embeds.pt"
if guidance_path.exists():
guidance_embeds = torch.load(guidance_path, map_location=DEVICE, weights_only=True)
print(f'Loaded precomputed guidance embeddings from "{guidance_path.as_posix()}"')
else:
guidance_embeds = precompute_guidance_embeds(transformer)
if CP_MESH is None or CP_MESH.get_local_rank() == 0:
torch.save(guidance_embeds, guidance_path.as_posix())
print(f'Precomputed guidance embeddings saved to "{save_dir.as_posix()}"')
if timestep_path.exists():
timestep_embeds = torch.load(timestep_path, map_location=DEVICE, weights_only=True)
print(f'Loaded precomputed timestep embeddings from "{timestep_path.as_posix()}"')
else:
timestep_embeds = precompute_timestep_embeds(transformer)
if CP_MESH is None or CP_MESH.get_local_rank() == 0:
torch.save(timestep_embeds, timestep_path.as_posix())
print(f'Precomputed timestep embeddings saved to "{save_dir.as_posix()}"')
return guidance_embeds, timestep_embeds
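# Record one transformer forward into a CUDA graph with static input/output buffers; the denoising loop then
# only copies fresh values into the static tensors and replays the graph.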
def capture_cudagraph(
model: FluxTransformer2DModel,
latents: torch.Tensor,
encoder_hidden_states: torch.Tensor,
conditioning: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
dt: torch.Tensor,
):
print("Capturing CUDA graph...")
static_latents = latents.clone()
static_conditioning = conditioning.clone()
static_dt = dt.clone()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
static_x = model(
hidden_states=static_latents,
encoder_hidden_states=encoder_hidden_states,
conditioning=static_conditioning,
image_rotary_emb=image_rotary_emb,
dt=static_dt,
)
return graph, static_latents, static_conditioning, static_dt, static_x
@torch.inference_mode()
@torch.compiler.set_stance("default", skip_guard_eval_unsafe=True)
def main(
model_id: str,
working_dir: str = "/tmp/flux_precomputation",
num_inference_steps: int = 28,
mode: str = "none",
cache_dir: Optional[str] = None,
enable_cudagraph: bool = False,
enable_profiling: bool = False,
):
prompt = "A cat holding a sign that says 'Hello, World!'"
transformer = FluxTransformer2DModel.from_pretrained(
model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=DTYPE
)
text_encoder = CLIPTextModel.from_pretrained(
model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=DTYPE
)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=DTYPE
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=DTYPE)
image_processor = VaeImageProcessor(vae_scale_factor=8 * 2)
fuse_qkv_(transformer)
fuse_adaln_linear_(transformer)
for module in (transformer, text_encoder, text_encoder_2, vae):
    module.to(DEVICE)
if mode != "none":
transformer = torch.compile(transformer, mode=mode, fullgraph=True, dynamic=False)
guidance_embeds, timestep_embeds = precompute_embeds(transformer, working_dir)
latents = torch.randn((BATCH_SIZE, LATENT_HEIGHT * LATENT_WIDTH, IN_CHANNELS * 2 * 2), dtype=DTYPE, device=DEVICE)
pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, DEVICE, DTYPE)
encoder_hidden_states = prepare_t5_embeddings(
text_encoder_2, tokenizer_2, prompt, DEVICE, DTYPE, T5_SEQUENCE_LENGTH
)
# <precompute>
guidance_conditioning = guidance_embeds[f"{GUIDANCE_SCALE:.1f}"]
sigmas, timestep_conditioning = timestep_embeds[num_inference_steps]
pooled_projections = transformer.time_text_embed.text_embedder(pooled_projections)
encoder_hidden_states = transformer.context_embedder(encoder_hidden_states)
# </precompute>
img_ids = prepare_latent_image_ids(LATENT_HEIGHT, LATENT_WIDTH, device=DEVICE, dtype=DTYPE)
text_ids = torch.zeros(T5_SEQUENCE_LENGTH, 3).to(device=DEVICE, dtype=DTYPE)
ids = torch.cat([text_ids, img_ids], dim=0).float()
image_rotary_emb = transformer.pos_embed(ids)
dt = sigmas[1:] - sigmas[:-1]
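# The Euler update x_{t+1} = x_t + dt * v is folded into the transformer forward, so each denoising step is a
# single (optionally compiled or graph-replayed) call.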
print("Warmup step...")
for _ in range(2):
conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections)
_ = transformer(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
conditioning=conditioning,
image_rotary_emb=image_rotary_emb,
dt=dt[0],
)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
context = (
profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True)
if enable_profiling
else contextlib.nullcontext()
)
if not enable_cudagraph:
with context as ctx:
start_event.record()
for i in range(num_inference_steps):
conditioning = pointwise_add3_silu(
timestep_conditioning[i, :], guidance_conditioning, pooled_projections
)
latents = transformer(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
conditioning=conditioning,
image_rotary_emb=image_rotary_emb,
dt=dt[i],
)
end_event.record()
torch.cuda.synchronize()
else:
graph, static_latents, static_conditioning, static_dt, static_x = capture_cudagraph(
transformer,
latents,
encoder_hidden_states,
pooled_projections + guidance_conditioning,
image_rotary_emb,
dt[0],
)
static_x.copy_(latents)
with context as ctx:
start_event.record()
for i in range(num_inference_steps):
conditioning = pointwise_add3_silu(
timestep_conditioning[i, :], guidance_conditioning, pooled_projections
)
static_latents.copy_(static_x)
static_conditioning.copy_(conditioning)
static_dt.copy_(dt[i])
graph.replay()
end_event.record()
torch.cuda.synchronize()
latents = static_x
total_time = start_event.elapsed_time(end_event) / 1000.0
print(f"time: {total_time:.5f}s")
if enable_profiling:
print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
ctx.export_chrome_trace(f"dump_benchmark_flux_normal_{mode}.json")
if CP_MESH is None or CP_MESH.get_local_rank() == 0:
latents = latents.reshape(-1, LATENT_HEIGHT, LATENT_WIDTH, IN_CHANNELS, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.flatten(4, 5).flatten(2, 3)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="pil")[0]
image.save("output.png")
if __name__ == "__main__":
model_id = "black-forest-labs/FLUX.1-dev"
cache_dir = None
parser = argparse.ArgumentParser()
parser.add_argument("--cp_degree", type=int, default=1, help="Degree of context parallelism.")
parser.add_argument("--working_dir", type=str, default="/tmp/flux_precomputation")
parser.add_argument("--enable_cudagraph", action="store_true")
parser.add_argument(
"--compile_mode",
type=str,
default="none",
choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
)
parser.add_argument("--num_inference_steps", type=int, default=28)
parser.add_argument("--reduce_rope_precision", action="store_true")
parser.add_argument("--attention_backend", type=str, default="torch_cudnn", choices=["torch_cudnn", "fa2", "fa3"])
parser.add_argument("--enable_profiling", action="store_true")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]:
raise ValueError("CUDAGraph requires compile_mode to be 'default' or 'max-autotune-no-cudagraphs'.")
if args.num_inference_steps > 50 or args.num_inference_steps < 2:
raise ValueError("num_inference_steps must be between 2 and 50.")
ROPE_PRECISION = torch.bfloat16 if args.reduce_rope_precision else torch.float32
if args.attention_backend == "fa2":
ATTENTION_OP = _attention_flash_attn_2
elif args.attention_backend == "fa3":
ATTENTION_OP = _attention_flash_attn_3
torch.manual_seed(args.seed)
try:
if args.cp_degree > 1:
dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
CP_MESH = dist.device_mesh.init_device_mesh("cuda", [world_size], mesh_dim_names=["cp"])
if args.enable_profiling and args.enable_cudagraph:
torch.profiler._utils._init_for_cuda_graphs()
main(
model_id=model_id,
working_dir=args.working_dir,
num_inference_steps=args.num_inference_steps,
mode=args.compile_mode,
cache_dir=cache_dir,
enable_cudagraph=args.enable_cudagraph,
enable_profiling=args.enable_profiling,
)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@a-r-r-o-w (Author)

[image]

@a-r-r-o-w (Author)

what
