Show averaged attention maps for the Prompt2PromptPipeline in 🤗 diffusers
# based on github.com/weifeng-Chen/prompt2prompt/ - shoutout to Weifeng Chen!

# Create pipeline and run it once
import torch
import numpy as np

from diffusers.pipelines import Prompt2PromptPipeline

pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
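# Optional: load in half precision to reduce GPU memory. A sketch, not required
# for the rest of the script; torch_dtype is a standard diffusers
# from_pretrained kwarg:
# pipe = Prompt2PromptPipeline.from_pretrained(
#     "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
# ).to("cuda")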
prompts = ["A turtle playing with a ball", | |
"A monkey playing with a ball"] | |
cross_attention_kwargs = { | |
"edit_type": "replace", | |
"cross_replace_steps": 0.4, | |
"self_replace_steps": 0.4 | |
} | |
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs) | |
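# The call above returns the usual diffusers output object (assumption: the
# Prompt2PromptPipeline follows StableDiffusionPipelineOutput), so the generated
# images are PIL images and can be saved directly; filenames are illustrative:
for idx, img in enumerate(outputs.images):
    img.save(f"p2p_output_{idx}.png")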
# Show average attention maps
import cv2
from PIL import Image
from IPython.display import display  # used by view_images below
def show_cross_attention(pipe, prompts, res: int, from_where, select: int = 0):
    """Render the per-token cross-attention maps of one prompt as a row of images."""
    tokens = pipe.tokenizer.encode(prompts[select])
    decoder = pipe.tokenizer.decode
    attention_maps = aggregate_attention(prompts, pipe.controller, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]                      # res x res map for token i
        image = 255 * image / image.max()                    # normalize to 0..255
        image = image.unsqueeze(-1).expand(*image.shape, 3)  # grayscale -> 3 channels
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))
def aggregate_attention(prompts, attention_store, res: int, from_where, is_cross: bool, select: int):
    """Average the stored attention maps over all heads and layers at resolution `res`."""
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:  # UNet block groups, e.g. "up", "down", "mid"
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:  # keep only maps at the requested resolution
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]  # mean over heads and layers
    return out.cpu()
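# The same helper also works for self-attention maps; a minimal sketch (assumes
# pipe.controller was populated by the pipeline run above):
# self_maps = aggregate_attention(prompts, pipe.controller, res=16,
#                                 from_where=("up", "down"), is_cross=False, select=0)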
def text_under_image(image, text: str, text_color=(0, 0, 0)):
    """Append a white strip below `image` and center `text` in it."""
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img
def view_images(images,
                num_rows: int = 1,
                offset_ratio: float = 0.02,
                display_image: bool = True):
    """Displays a list of images in a grid."""
    if type(images) is list:
        num_empty = -len(images) % num_rows  # pad so the grid is completely filled
    elif images.ndim == 4:
        num_empty = -images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)
    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h,
                   j * (w + offset): j * (w + offset) + w] = images[i * num_cols + j]
    pil_img = Image.fromarray(image_)
    if display_image:
        display(pil_img)
    return pil_img
show_cross_attention(pipe, prompts, res=16, from_where=("up", "down"), select=0)
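# `select` indexes into `prompts`; to inspect the attention maps of the edited
# (second) prompt instead:
show_cross_attention(pipe, prompts, res=16, from_where=("up", "down"), select=1)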