import inspect
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.safety_checker import cosine_distance
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


def preprocess_init_image(image: Image.Image, width: int, height: int) -> torch.FloatTensor:
    image = image.resize((width, height), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # (H, W, C) -> (1, C, H, W)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0  # rescale from [0, 1] to [-1, 1]


def preprocess_mask(mask: Image.Image, width: int, height: int) -> torch.FloatTensor:
    mask = mask.convert("L")
    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))  # repeat across the 4 latent channels
    mask = mask[None]  # add a batch dimension: (4, H/8, W/8) -> (1, 4, H/8, W/8)
    mask = torch.from_numpy(mask)
    return mask
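

# Shape sketch for the two helpers above (a 512x512 RGB input is assumed here
# purely for illustration):
#   preprocess_init_image(img, 512, 512) -> (1, 3, 512, 512), values in [-1, 1]
#   preprocess_mask(msk, 512, 512)       -> (1, 4, 64, 64),   values in [0, 1]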


class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    """
    From https://github.com/huggingface/diffusers/pull/241
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Optional[torch.FloatTensor],
        mask: Optional[torch.FloatTensor],
        width: int,
        height: int,
        prompt_strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
    ) -> dict:
        if isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )
        if prompt_strength < 0 or prompt_strength > 1:
            raise ValueError(
                f"The value of prompt_strength should be in [0.0, 1.0] but is {prompt_strength}"
            )
        if mask is not None and init_image is None:
            raise ValueError(
                "If mask is defined, then init_image also needs to be defined"
            )
        if width % 8 != 0 or height % 8 != 0:
            raise ValueError("Width and height must both be divisible by 8")
        # set timesteps
        accepts_offset = "offset" in set(
            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
        )
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        text_embeddings = self.embed_text(prompt=prompt)
        uncond_embeddings = self.embed_text(
            prompt=[""] * batch_size
        )  # only need to call this once
        if init_image is not None:
            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
                init_image,
                prompt_strength,
                offset,
                num_inference_steps,
                batch_size,
                generator,
            )
        else:
            latents = torch.randn(
                (batch_size, self.unet.in_channels, height // 8, width // 8),
                generator=generator,
                device=self.device,
            )
            init_timestep = num_inference_steps
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            sd_hidden_states = torch.cat([uncond_embeddings, text_embeddings], dim=0)
        else:
            sd_hidden_states = text_embeddings
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)
        # if we use LMSDiscreteScheduler, make sure fresh latents are multiplied by
        # sigma_max; latents derived from an init image were already noised by add_noise
        if init_image is None and isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]
        t_start = max(num_inference_steps - init_timestep + offset, 0)
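        # Worked example with the defaults (num_inference_steps=50,
        # prompt_strength=0.8) and offset=1:
        #   init_timestep = int(50 * 0.8) + 1 = 41
        #   t_start       = max(50 - 41 + 1, 0) = 10
        # i.e. denoising runs over the last 40 of the 50 scheduled timesteps.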
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            # index into the full schedule; i alone would be wrong whenever the
            # run starts mid-schedule (i.e. when an init image is given)
            t_index = t_start + i
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[t_index]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=sd_hidden_states,  # hidden state/context is the uncond embedding and the text embedding
            )["sample"]
            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(
                    noise_pred, t_index, latents, **extra_step_kwargs
                )["prev_sample"]
            else:
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                )["prev_sample"]
            # composite with the mask: where mask is 1, paste back the original
            # latents re-noised to the current timestep; where mask is 0, keep
            # the freshly denoised latents
            if mask is not None:
                timesteps = self.scheduler.timesteps[t_index]
                timesteps = torch.tensor(
                    [timesteps] * batch_size, dtype=torch.long, device=self.device
                )
                noisy_init_latents = self.scheduler.add_noise(
                    init_latents_orig, mask_noise, timesteps
                )
                latents = noisy_init_latents * mask + latents * (1 - mask)
        # scale and decode the image latents with the vae
        latents = 1 / 0.18215 * latents
        decoded_image = self.vae.decode(latents)
        decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)
        decoded_image = decoded_image.cpu().permute(0, 2, 3, 1).numpy()
        # only need the "first" prompt since it is repeated batch_size times
        pooled_text_embeddings = self.embed_text(prompt=[prompt[0]], pooled_output=True)
        # get pooled image embeddings for cosine similarity (1, 768)
        clip_image_tensors = self.feature_extractor(
            self.numpy_to_pil(decoded_image), return_tensors="pt"
        ).to(self.device)
        clip_image_tensors = clip_image_tensors.pixel_values
        pooled_image_embeddings = self.image_to_embedding(image_tensors=clip_image_tensors)
        scores = (
            cosine_distance(pooled_image_embeddings, pooled_text_embeddings)
        ).detach().cpu().numpy()
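        # note: despite its name, `cosine_distance` from the safety checker is a
        # normalized dot product, so higher scores mean closer text/image agreement;
        # the text side here skips CLIP's learned text projection, so treat this as
        # a rough relative ranking signal rather than a calibrated CLIP similarity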
        pil_image = self.numpy_to_pil(decoded_image)
        return {
            "sample": pil_image,
            "score": scores,
            # "nsfw_content_detected": has_nsfw_concept,
        }
        # # check safety
        # _, has_nsfw_concept = self.safety_checker(
        #     images=image, clip_input=image_tensors
        # )  # TODO: check if this is correct
    def latents_from_init_image(
        self,
        init_image: torch.FloatTensor,
        prompt_strength: float,
        offset: int,
        num_inference_steps: int,
        batch_size: int,
        generator: Optional[torch.Generator],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
        # encode the init image into latents and scale the latents
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = 0.18215 * init_latents
        init_latents_orig = init_latents
        # expand init_latents to the batch size
        init_latents = torch.cat([init_latents] * batch_size)
        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * prompt_strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [timesteps] * batch_size, dtype=torch.long, device=self.device
        )
        # add noise to the latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
        return init_latents_orig, init_latents, init_timestep
    def embed_text(
        self,
        prompt: List[str],
        pooled_output: bool = False,
    ) -> torch.FloatTensor:
        """Embed text with the CLIP text encoder."""
        assert isinstance(
            prompt, list
        ), f"prompt {prompt} must be a list of strings, try wrapping it in a list e.g. [prompt]"
        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeds = self.text_encoder(
            text_input.input_ids.to(self.device)
        )  # last_hidden_state (batch_size, seq_len, dim), pooler_output (batch_size, dim)
        if pooled_output:
            return text_embeds[1]  # pooler_output
        else:
            return text_embeds[0]  # last_hidden_state
    def image_to_embedding(self, image_tensors: torch.FloatTensor) -> torch.FloatTensor:
        pooled_image_embeddings = self.safety_checker.vision_model(
            pixel_values=image_tensors
        ).pooler_output
        # project the pooled image embeddings to the same dimension as the text embeddings
        pooled_image_embeddings = self.safety_checker.visual_projection(
            pooled_image_embeddings
        )
        return pooled_image_embeddings
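

# Minimal usage sketch. The model id, device, file paths, and prompt below are
# illustrative assumptions; any Stable Diffusion v1 checkpoint with the standard
# diffusers layout should load the same way.
if __name__ == "__main__":
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4"  # assumed checkpoint id
    ).to("cuda")

    init = Image.open("init.png").convert("RGB")  # hypothetical input image
    msk = Image.open("mask.png")  # hypothetical mask; white regions keep the original

    result = pipe(
        prompt="a photograph of an astronaut riding a horse",
        init_image=preprocess_init_image(init, 512, 512).to("cuda"),
        mask=preprocess_mask(msk, 512, 512).to("cuda"),
        width=512,
        height=512,
        prompt_strength=0.8,
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    result["sample"][0].save("out.png")
    print("score:", result["score"])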