SDXL Attention stats
import math
import torch
import torch.nn.functional as F
from typing import Optional
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention import Attention, BasicTransformerBlock


# ------------------------------------------------------------------------------
# 1. StatsCollector and CustomAttnProcessor2_0 definition
# ------------------------------------------------------------------------------
class StatsCollector:
    """
    Container for storing attention statistics and the current diffusion timestep.
    """
    def __init__(self):
        # Mapping: timestep -> list of stats dictionaries.
        self.attention_stats = {}
        self.current_timestep = None
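
# Sketch of the layout that `attention_stats` ends up with after a run (keys are
# the timesteps captured by the UNet pre-hook installed below, or "unknown" if
# no timestep could be read; each list entry comes from one Attention call):
#
#     stats_object.attention_stats == {
#         <timestep>: [
#             {"module": "<block name>.attn1", "attn_type": "self",
#              "mean": ..., "std": ..., "min": ..., "max": ...},
#             ...
#         ],
#         ...
#     }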


class CustomAttnProcessor2_0:
    """
    Custom attention processor based on AttnProcessor2_0 that logs statistics
    from the attention probabilities (i.e. softmax(QK^T)) computed during scaled
    dot-product attention.

    This processor replicates the default behavior but computes the attention
    scores manually so that statistics can be recorded.
    """
    def __init__(self, stats_object: StatsCollector, attn_module_name: str):
        self.stats_object = stats_object
        self.attn_module_name = attn_module_name
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CustomAttnProcessor2_0 requires PyTorch 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        # (Optional: handle deprecated arguments if needed.)
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        else:
            # For non-image inputs.
            batch_size = hidden_states.shape[0]

        # Determine sequence length.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects the mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape query, key, and value to (batch, heads, seq_len, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Compute scaled dot-product attention manually.
        scaling = 1.0 / math.sqrt(head_dim)
        attn_scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        # Compute the attention probabilities.
        attention_probs = torch.softmax(attn_scores, dim=-1)

        # Log statistics from the attention probabilities.
        stats = {
            "module": self.attn_module_name,
            "mean": attention_probs.mean().item(),
            "std": attention_probs.std().item(),
            "min": attention_probs.min().item(),
            "max": attention_probs.max().item(),
        }
        # Determine attention type based on whether encoder_hidden_states was provided.
        attn_type = "self" if encoder_hidden_states is hidden_states or encoder_hidden_states is None else "cross"
        stats["attn_type"] = attn_type
        ts = self.stats_object.current_timestep if self.stats_object.current_timestep is not None else "unknown"
        if ts not in self.stats_object.attention_stats:
            self.stats_object.attention_stats[ts] = []
        self.stats_object.attention_stats[ts].append(stats)

        # Compute the attention output.
        hidden_states = torch.matmul(attention_probs, value)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Linear projection.
        hidden_states = attn.to_out[0](hidden_states)
        # Dropout.
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
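

# Illustrative only (not called anywhere in this script): the manual
# softmax(QK^T / sqrt(d)) V path computed in __call__ above should match
# PyTorch's fused kernel up to numerical precision. Assumes query/key/value are
# already shaped (batch, heads, seq_len, head_dim) and that no mask is used.
def _check_manual_attention_matches_sdpa(query, key, value, atol=1e-3):
    scaling = 1.0 / math.sqrt(query.shape[-1])
    attn_scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
    attention_probs = torch.softmax(attn_scores, dim=-1)
    manual = torch.matmul(attention_probs, value)
    fused = F.scaled_dot_product_attention(query, key, value)
    return torch.allclose(manual, fused, atol=atol)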


# ------------------------------------------------------------------------------
# 2. Registration: Replace Attn Processors and Install a UNet Pre-hook
# ------------------------------------------------------------------------------
def register_all_attn_processors(unet, stats_object: StatsCollector):
    """
    Registers modifications on the UNet:
      - A pre-forward hook on the UNet to capture the current diffusion timestep.
      - Replacement of each Attention module's processor (in every BasicTransformerBlock)
        with a CustomAttnProcessor2_0 instance that logs attention statistics.

    Returns:
        patch_handles: A list of hook handles (for the UNet pre-hook).
    """
    patch_handles = []

    # Pre-hook to record the diffusion timestep.
    def unet_pre_hook(module, inputs):
        # Assumes UNet.forward() is called with signature:
        # (latents, timestep, encoder_hidden_states, ...)
        if len(inputs) > 1:
            ts = inputs[1]
            stats_object.current_timestep = ts.item() if isinstance(ts, torch.Tensor) else ts
        else:
            stats_object.current_timestep = None

    handle = unet.register_forward_pre_hook(unet_pre_hook)
    patch_handles.append(handle)

    # Iterate over all BasicTransformerBlock modules and replace their processors.
    for module_name, module in unet.named_modules():
        if isinstance(module, BasicTransformerBlock):
            block_name = module_name
            # Replace processor for self-attention (attn1), if present.
            if hasattr(module, "attn1") and module.attn1 is not None:
                processor = CustomAttnProcessor2_0(stats_object, block_name + ".attn1")
                if hasattr(module.attn1, "set_processor"):
                    module.attn1.set_processor(processor)
                else:
                    module.attn1.processor = processor
            # Replace processor for cross-attention (attn2), if present.
            if hasattr(module, "attn2") and module.attn2 is not None:
                processor = CustomAttnProcessor2_0(stats_object, block_name + ".attn2")
                if hasattr(module.attn2, "set_processor"):
                    module.attn2.set_processor(processor)
                else:
                    module.attn2.processor = processor

    return patch_handles
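

# Optional cleanup sketch: besides removing the returned pre-hook handles with
# `handle.remove()`, the stock attention processors can be restored afterwards
# via the UNet's own helper (available in recent diffusers versions):
#
#     pipeline.unet.set_default_attn_processor()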


# ------------------------------------------------------------------------------
# 3. Image Generation and Statistics Display
# ------------------------------------------------------------------------------
def generate_image(pipeline):
    """
    Runs diffusion inference to generate an image.
    """
    prompt = "A painting of a futuristic cityscape"
    num_inference_steps = 50
    print("Running diffusion inference (this may take a while)...")
    result = pipeline(prompt, num_inference_steps=num_inference_steps)
    return result.images[0]


def display_save_stats(stats_object: StatsCollector):
    """
    Displays the collected attention statistics.
    """
    print("\nAttention statistics per diffusion timestep:")
    for ts in sorted(stats_object.attention_stats, key=lambda t: float(t) if t != "unknown" else -1):
        stats_list = stats_object.attention_stats[ts]
        print(f"  Timestep {ts}: {len(stats_list)} attention activations recorded.")
        for stat in stats_list:
            print(
                f"    Module {stat['module']} ({stat['attn_type']} attention): "
                f"mean={stat['mean']:.4f}, std={stat['std']:.4f}, "
                f"min={stat['min']:.4f}, max={stat['max']:.4f}"
            )
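

# Extra convenience sketch, not wired into main(): collapses the per-module
# entries into a single average mean/std per timestep for a quicker overview
# than the full per-module dump printed by display_save_stats().
def summarize_stats(stats_object: StatsCollector):
    summary = {}
    for ts, stats_list in stats_object.attention_stats.items():
        if not stats_list:
            continue
        summary[ts] = {
            "mean": sum(s["mean"] for s in stats_list) / len(stats_list),
            "std": sum(s["std"] for s in stats_list) / len(stats_list),
        }
    return summary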


def main():
    # -------------------------------
    # Load the Stable Diffusion pipeline and prepare the UNet.
    # -------------------------------
    # Note: use `from_single_file` for local SGM models.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline.to(device)
    pipeline.unet.eval()

    # -------------------------------
    # Create a shared stats object and install our minimal patch.
    # -------------------------------
    stats_object = StatsCollector()
    patch_handles = register_all_attn_processors(pipeline.unet, stats_object)

    # -------------------------------
    # Generate the image.
    # -------------------------------
    image = generate_image(pipeline)
    image.save("generated_image.png")
    print("\nImage saved as 'generated_image.png'.")

    # -------------------------------
    # Display the collected attention statistics.
    # -------------------------------
    display_save_stats(stats_object)

    # (Optional) Remove hooks after inference.
    for handle in patch_handles:
        handle.remove()


if __name__ == "__main__":
    main()
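

# Usage sketch: with diffusers installed and PyTorch >= 2.0 (the custom
# processor requires F.scaled_dot_product_attention), run this file directly,
# e.g. `python sdxl_attention_stats.py` (filename is a placeholder). It downloads
# SDXL base 1.0 on first use, writes `generated_image.png`, and prints the
# per-timestep attention statistics to stdout.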