Skip to content

Instantly share code, notes, and snippets.

@UmerHA
Last active March 7, 2024 09:20
Show Gist options
  • Save UmerHA/b65bb5fb9626c9c73f3ade2869e36164 to your computer and use it in GitHub Desktop.
Save UmerHA/b65bb5fb9626c9c73f3ade2869e36164 to your computer and use it in GitHub Desktop.
Show averaged attention maps for the Prompt2PromptPipeline in 🤗 diffusers
# Based on github.com/weifeng-Chen/prompt2prompt/ - shoutout to Weifeng Chen!
# Create pipeline and run it once
import torch
import numpy as np
import matplotlib.pyplot as plt
from diffusers.pipelines import Prompt2PromptPipeline
# Load Stable Diffusion v1.4 into the Prompt2Prompt pipeline. Use the GPU when
# one is available and fall back to CPU otherwise (the previous hard-coded
# "cuda" failed on CPU-only machines; CPU works but is much slower).
pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Source prompt and edited prompt; they differ in a single word.
prompts = ["A turtle playing with a ball",
           "A monkey playing with a ball"]
# Prompt-to-Prompt edit settings forwarded to the pipeline's attention
# controller. NOTE(review): presumably "replace" swaps the differing token's
# attention and the *_replace_steps values are the fraction of denoising steps
# during which attention is injected — confirm against the pipeline docs.
cross_attention_kwargs = {
    "edit_type": "replace",
    "cross_replace_steps": 0.4,
    "self_replace_steps": 0.4,
}
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50,
               cross_attention_kwargs=cross_attention_kwargs)
# Show average attention maps
from PIL import Image
import cv2
def show_cross_attention(pipe, prompts, res: int, from_where, select: int = 0):
    """Display one averaged cross-attention heat map per token of a prompt.

    Each map is scaled to 0-255, resized to 256x256, captioned with its decoded
    token and tiled into a single row via ``view_images``.

    Args:
        pipe: the pipeline that was run; its ``tokenizer`` and ``controller``
            (presumably the attention store filled during the last run) are used.
        prompts: the prompts the pipeline was run with.
        res: side length of the square attention maps to aggregate
            (e.g. 16 -> maps with 16 * 16 pixels).
        from_where: attention locations to aggregate over, e.g. ``("up", "down")``.
        select: index into ``prompts`` of the prompt to visualize.
    """
    token_ids = pipe.tokenizer.encode(prompts[select])
    attention_maps = aggregate_attention(prompts, pipe.controller, res, from_where, True, select)
    # The maps only have attention_maps.shape[-1] token columns (presumably the
    # text encoder's max length). encode() does not truncate, so a longer prompt
    # would index past the end; show only tokens that actually have a map.
    token_ids = token_ids[:attention_maps.shape[-1]]
    images = []
    for i, token_id in enumerate(token_ids):
        heat = attention_maps[:, :, i]
        peak = heat.max()
        if peak > 0:  # an all-zero map would otherwise become 0/0 -> NaN
            heat = 255 * heat / peak
        # Grey-scale -> 3 identical channels so PIL/cv2 treat it as RGB.
        heat = heat.unsqueeze(-1).expand(*heat.shape, 3)
        heat = heat.numpy().astype(np.uint8)
        heat = np.array(Image.fromarray(heat).resize((256, 256)))
        images.append(text_under_image(heat, pipe.tokenizer.decode(int(token_id))))
    view_images(np.stack(images, axis=0))
def aggregate_attention(prompts, attention_store, res: int, from_where, is_cross: bool, select: int):
    """Average the stored attention maps of one prompt at a single resolution.

    Args:
        prompts: the prompts the pipeline was run with; only ``len(prompts)`` is
            used, to split the stored leading dimension per prompt.
        attention_store: object whose ``get_average_attention()`` returns a dict
            keyed ``"<location>_cross"`` / ``"<location>_self"``, each mapping to
            a list of tensors.
        res: side length of the square maps to keep; only tensors whose second
            dimension equals ``res ** 2`` are used.
        from_where: iterable of location names, e.g. ``("up", "down")``.
        is_cross: read the ``*_cross`` entries if True, else the ``*_self`` ones.
        select: index of the prompt whose maps are returned.

    Returns:
        CPU tensor of shape ``(res, res, n)``, where ``n`` is the last dimension
        of the stored tensors (presumably the token axis for cross-attention),
        averaged over the per-prompt second axis (presumably heads) of every
        matching tensor.

    Raises:
        ValueError: if no stored tensor at the given locations has ``res ** 2``
            pixels (previously an opaque error from ``torch.cat([])``).
    """
    kind = "cross" if is_cross else "self"
    locations = tuple(from_where)
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    out = []
    for location in locations:
        for item in attention_maps[f"{location}_{kind}"]:
            if item.shape[1] != num_pixels:
                continue  # tensor belongs to a different spatial resolution
            # item is expected to be (len(prompts) * k, res * res, n) with the
            # prompt index outermost; split it and keep the selected prompt.
            per_prompt = item.reshape(len(prompts), -1, res, res, item.shape[-1])
            out.append(per_prompt[select])
    if not out:
        raise ValueError(
            f"No {kind}-attention maps with {num_pixels} pixels (res={res}) "
            f"found at locations {locations}"
        )
    stacked = torch.cat(out, dim=0)
    return stacked.mean(dim=0).cpu()
def text_under_image(image, text: str, text_color=(0, 0, 0)):
    """Return a copy of ``image`` with a caption strip containing ``text`` below it.

    Args:
        image: array of shape (h, w, c); it is copied into a new uint8 canvas
            whose extra bottom strip (int(0.2 * h) rows) is filled with 255.
        text: caption, horizontally centred in the strip; the text's
            bottom-left origin (``cv2.putText``'s ``org``) sits half a
            text-height above the canvas' bottom edge.
        text_color: colour tuple passed straight to ``cv2.putText``.

    Returns:
        A new uint8 array of shape (h + int(0.2 * h), w, c).
    """
    height, width, channels = image.shape
    strip = int(height * .2)
    canvas = np.full((height + strip, width, channels), 255, dtype=np.uint8)
    canvas[:height] = image
    face = cv2.FONT_HERSHEY_SIMPLEX
    # Font scale 1, thickness 2 — must match the putText call below.
    (text_w, text_h), _ = cv2.getTextSize(text, face, 1, 2)
    origin = ((width - text_w) // 2, height + strip - text_h // 2)
    cv2.putText(canvas, text, origin, face, 1, text_color, 2)
    return canvas
def view_images(images,
                num_rows: int = 1,
                offset_ratio: float = 0.02,
                display_image: bool = True):
    """Tile images into a grid separated by white gaps and optionally display it.

    Args:
        images: a list/tuple of (h, w, c) arrays, a 4-D array of shape
            (n, h, w, c), or a single 3-D (h, w, c) array. Every image is
            assumed to have the shape of the first one.
        num_rows: number of grid rows. When the image count is not a multiple
            of ``num_rows``, the trailing cells are filled with white tiles.
        offset_ratio: gap between tiles, as a fraction of the tile height.
        display_image: if True, show the grid with ``display`` before returning.

    Returns:
        The grid as a PIL image (built with ``Image.fromarray``).
    """
    if isinstance(images, (list, tuple)):
        num_images = len(images)
    elif images.ndim == 4:
        num_images = images.shape[0]
    else:
        # A single (h, w, c) image.
        images = [images]
        num_images = 1
    # Pad with blank tiles up to the next multiple of num_rows so every image
    # gets a cell. The old `num_images % num_rows` under-padded for
    # num_rows >= 3 (e.g. 7 images in 3 rows -> 2 columns -> only 6 shown).
    num_empty = -num_images % num_rows
    blank = np.ones(images[0].shape, dtype=np.uint8) * 255
    tiles = [image.astype(np.uint8) for image in images] + [blank] * num_empty
    h, w, _ = tiles[0].shape
    offset = int(h * offset_ratio)
    num_cols = len(tiles) // num_rows
    # The canvas is always 3-channel; single-channel tiles broadcast into it.
    grid = np.ones((h * num_rows + offset * (num_rows - 1),
                    w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, num_cols)
        top = row * (h + offset)
        left = col * (w + offset)
        grid[top:top + h, left:left + w] = tile
    pil_img = Image.fromarray(grid)
    if display_image:
        # `display` is not defined in this file; presumably it is the
        # IPython/Jupyter builtin, so this likely only works in a notebook.
        display(pil_img)
    return pil_img
# Visualize the 16x16 cross-attention maps of the first prompt, aggregated
# over the "up" and "down" attention locations.
show_cross_attention(
    pipe,
    prompts,
    res=16,
    from_where=("up", "down"),
    select=0,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment