Img2Img OVAM example

OVAM (open-vocabulary attention maps) hooked into the diffusers StableDiffusionImg2ImgPipeline: generate an image from a sketch with Stable Diffusion v1.5, then extract and overlay the attention map for the token "castle".
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from ovam import StableDiffusionHooker
from ovam.utils import set_seed
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from PIL import Image
# Load the Stable Diffusion img2img pipeline in half precision on the GPU
device = "cuda"
model_id_or_path = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe = pipe.to(device)
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" | |
response = requests.get(url) | |
init_image = Image.open(BytesIO(response.content)).convert("RGB") | |
init_image = init_image.resize((512, 512)) | |
prompt = "A fantasy landscape, trending on artstation"

# Generate the image while storing the internal attentions with the hooker
set_seed(1)
with StableDiffusionHooker(pipe) as hooker:
    image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]
# Build the OVAM evaluator from the stored attentions; it can be queried with
# arbitrary open-vocabulary text without re-running the diffusion process
ovam_evaluator = hooker.get_ovam_callable(expand_size=(512, 512))

# Single-word query: one attention map per token
with torch.no_grad():
    attention_maps = ovam_evaluator("castle")
    attention_maps = attention_maps[0].cpu().numpy()  # (3, 512, 512)

# Query actually used for the plot below (overwrites the previous result)
attribution_prompt = "A castle"
with torch.no_grad():
    attention_maps = ovam_evaluator(attribution_prompt)
    attention_maps = attention_maps[0].cpu().numpy()  # one map per token: <SoT>, "a", "castle", <EoT>
# Plot: input sketch, generated image, and generated image with the "castle"
# attention overlaid
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(10, 4))
ax0.imshow(init_image)
ax1.imshow(image)
ax2.imshow(image)

# Normalize the attention to [0, 1] and use it as the alpha channel of the overlay
castle_attention = attention_maps[2]  # "castle" is the third token: <SoT> a castle <EoT>
castle_attention = castle_attention - castle_attention.min()
castle_attention = castle_attention / castle_attention.max()
ax2.imshow(castle_attention, alpha=castle_attention.astype(float), cmap='jet')
ax0.set_title("Init image")
ax0.axis('off')
ax1.set_title("Generated image")
ax1.axis('off')
ax2.set_title("+ Castle attentions")
ax2.axis('off')
fig.savefig("example_img2img.jpg")
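
Optional extension (a minimal sketch, not part of the original gist): the normalized attention map can be thresholded into a rough binary mask of the attributed region. Only numpy and the already-imported PIL are used; the 0.5 threshold and the castle_mask.png filename are arbitrary assumptions.

import numpy as np

# Threshold the normalized "castle" attention into a rough binary mask
# (the 0.5 cutoff is an assumption; tune it for your image)
mask = (castle_attention > 0.5).astype(np.uint8) * 255
Image.fromarray(mask).save("castle_mask.png")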