SDXL Attention stats

A script that patches every attention module in the SDXL UNet with a custom attention processor, recording statistics of the attention probabilities (softmax(QK^T)) at each diffusion timestep.
import math
import torch
import torch.nn.functional as F
from typing import Optional
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention import Attention, BasicTransformerBlock


# ------------------------------------------------------------------------------
# 1. StatsCollector and CustomAttnProcessor2_0 definition
# ------------------------------------------------------------------------------

class StatsCollector:
    """
    Container for storing attention statistics and the current diffusion timestep.
    """

    def __init__(self):
        # Mapping: timestep -> list of stats dictionaries.
        self.attention_stats = {}
        self.current_timestep = None
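
# Illustration of the collected layout after a run (values are placeholders):
# keys are the diffusion timesteps captured by the UNet pre-hook below, values
# are the per-module stats dicts appended by the custom processor.
#
#   stats_object.attention_stats[<timestep>] == [
#       {"module": "<block>.attn1", "attn_type": "self",
#        "mean": ..., "std": ..., "min": ..., "max": ...},
#       ...
#   ]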

class CustomAttnProcessor2_0:
    """
    Custom attention processor based on AttnProcessor2_0 that logs statistics
    from the attention probabilities (i.e. softmax(QK^T)) computed during
    scaled dot-product attention.

    This processor replicates the default behavior but computes the attention
    scores manually so that statistics can be recorded.
    """

    def __init__(self, stats_object: StatsCollector, attn_module_name: str):
        self.stats_object = stats_object
        self.attn_module_name = attn_module_name
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CustomAttnProcessor2_0 requires PyTorch 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        # (Optional: handle deprecated arguments if needed.)
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten spatial dimensions: (b, c, h, w) -> (b, h*w, c).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Determine batch size and sequence length; for cross-attention the
        # keys come from the encoder hidden states.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects the mask shape to be
            # (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape query, key, and value to (batch, heads, seq_len, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # QK-norm layers only exist on newer diffusers Attention modules, so
        # use getattr for backward compatibility.
        if getattr(attn, "norm_q", None) is not None:
            query = attn.norm_q(query)
        if getattr(attn, "norm_k", None) is not None:
            key = attn.norm_k(key)
        # Compute scaled dot-product attention manually. The scale matches the
        # default used by F.scaled_dot_product_attention; unlike the fused
        # kernel, this materializes the full (batch, heads, q_len, kv_len)
        # probability matrix so its statistics can be read.
        scaling = 1.0 / math.sqrt(head_dim)
        attn_scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        # Compute the attention probabilities.
        attention_probs = torch.softmax(attn_scores, dim=-1)

        # Log statistics from the attention probabilities. Note that each
        # softmax row sums to 1, so `mean` is always exactly 1 / kv_len; the
        # distributional information is carried by std/min/max.
        stats = {
            "module": self.attn_module_name,
            "mean": attention_probs.mean().item(),
            "std": attention_probs.std().item(),
            "min": attention_probs.min().item(),
            "max": attention_probs.max().item(),
        }
        # encoder_hidden_states was aliased to hidden_states above when it was
        # None, so an identity check distinguishes self- from cross-attention.
        attn_type = "self" if encoder_hidden_states is hidden_states else "cross"
        stats["attn_type"] = attn_type

        ts = self.stats_object.current_timestep if self.stats_object.current_timestep is not None else "unknown"
        self.stats_object.attention_stats.setdefault(ts, []).append(stats)
        # Compute the attention output.
        hidden_states = torch.matmul(attention_probs, value)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Linear projection.
        hidden_states = attn.to_out[0](hidden_states)
        # Dropout.
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # Restore the original (b, c, h, w) layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

# ------------------------------------------------------------------------------
# 2. Registration: Replace Attn Processors and Install a UNet Pre-hook
# ------------------------------------------------------------------------------

def register_all_attn_processors(unet, stats_object: StatsCollector):
    """
    Registers modifications on the UNet:
      - A pre-forward hook on the UNet to capture the current diffusion timestep.
      - Replacement of each Attention module's processor (in every
        BasicTransformerBlock) with a CustomAttnProcessor2_0 instance that logs
        attention statistics.

    Returns:
        patch_handles: A list of hook handles (for the UNet pre-hook).
    """
    patch_handles = []

    # Pre-hook to record the diffusion timestep.
    def unet_pre_hook(module, inputs):
        # Assumes UNet.forward() is called with the signature
        # (latents, timestep, encoder_hidden_states, ...), i.e. the timestep
        # is the second positional argument.
        if len(inputs) > 1:
            ts = inputs[1]
            stats_object.current_timestep = ts.item() if isinstance(ts, torch.Tensor) else ts
        else:
            stats_object.current_timestep = None

    handle = unet.register_forward_pre_hook(unet_pre_hook)
    patch_handles.append(handle)
    # Iterate over all BasicTransformerBlock modules and replace their processors.
    for module_name, module in unet.named_modules():
        if isinstance(module, BasicTransformerBlock):
            block_name = module_name
            # Replace processor for self-attention (attn1), if present.
            if hasattr(module, "attn1") and module.attn1 is not None:
                processor = CustomAttnProcessor2_0(stats_object, block_name + ".attn1")
                if hasattr(module.attn1, "set_processor"):
                    module.attn1.set_processor(processor)
                else:
                    module.attn1.processor = processor
            # Replace processor for cross-attention (attn2), if present.
            if hasattr(module, "attn2") and module.attn2 is not None:
                processor = CustomAttnProcessor2_0(stats_object, block_name + ".attn2")
                if hasattr(module.attn2, "set_processor"):
                    module.attn2.set_processor(processor)
                else:
                    module.attn2.processor = processor
    return patch_handles
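
# A minimal sketch (not part of the original script) for undoing the patch once
# profiling is done. `set_attn_processor` is the standard diffusers API for
# swapping every processor at once; assumes a diffusers version that ships
# AttnProcessor2_0.
def restore_default_processors(unet):
    from diffusers.models.attention_processor import AttnProcessor2_0

    # Replace all registered processors with the stock PyTorch 2.0 one.
    unet.set_attn_processor(AttnProcessor2_0())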

# ------------------------------------------------------------------------------
# 3. Image Generation and Statistics Display
# ------------------------------------------------------------------------------

def generate_image(pipeline):
    """
    Runs diffusion inference to generate an image.
    """
    prompt = "A painting of a futuristic cityscape"
    num_inference_steps = 50
    print("Running diffusion inference (this may take a while)...")
    result = pipeline(prompt, num_inference_steps=num_inference_steps)
    return result.images[0]
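
# If reproducible statistics are needed across runs, a seeded generator can be
# passed to the pipeline (a suggestion, not part of the original script), e.g.:
#   generator = torch.Generator(device=pipeline.device).manual_seed(0)
#   result = pipeline(prompt, num_inference_steps=50, generator=generator)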

def display_stats(stats_object: StatsCollector):
    """
    Displays the collected attention statistics, grouped by diffusion timestep.
    (Renamed from display_save_stats, since it only prints; see save_stats_json
    below for persisting the data.)
    """
    print("\nAttention statistics per diffusion timestep:")
    # Sort numerically; "unknown" timesteps sort first via the -1 sentinel.
    for ts in sorted(stats_object.attention_stats, key=lambda t: float(t) if t != "unknown" else -1):
        stats_list = stats_object.attention_stats[ts]
        print(f"  Timestep {ts}: {len(stats_list)} attention activations recorded.")
        for stat in stats_list:
            print(
                f"    Module {stat['module']} ({stat['attn_type']} attention): "
                f"mean={stat['mean']:.4f}, std={stat['std']:.4f}, "
                f"min={stat['min']:.4f}, max={stat['max']:.4f}"
            )
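
# A minimal companion sketch (not in the original script) for persisting the
# collected statistics; the default filename and the flat JSON layout are
# assumptions.
def save_stats_json(stats_object: StatsCollector, path: str = "attention_stats.json"):
    import json

    # JSON object keys must be strings; timesteps may be floats or "unknown".
    serializable = {str(ts): stats for ts, stats in stats_object.attention_stats.items()}
    with open(path, "w") as f:
        json.dump(serializable, f, indent=2)
    print(f"Attention statistics saved to '{path}'.")
# Usage: call save_stats_json(stats_object) after display_stats(...) in main().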

def main():
    # -------------------------------
    # Load the Stable Diffusion XL pipeline and prepare the UNet.
    # -------------------------------
    # Note: use `from_single_file` for local SGM models.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,  # float16 assumes a CUDA device; use float32 on CPU.
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline.to(device)
    pipeline.unet.eval()

    # -------------------------------
    # Create a shared stats object and install our minimal patch.
    # -------------------------------
    stats_object = StatsCollector()
    patch_handles = register_all_attn_processors(pipeline.unet, stats_object)

    # -------------------------------
    # Generate the image.
    # -------------------------------
    image = generate_image(pipeline)
    image.save("generated_image.png")
    print("\nImage saved as 'generated_image.png'.")

    # -------------------------------
    # Display the collected attention statistics.
    # -------------------------------
    display_stats(stats_object)

    # (Optional) Remove hooks after inference.
    for handle in patch_handles:
        handle.remove()


if __name__ == "__main__":
    main()