import os
from typing import List

import numpy as np
import torch
from cog import BasePredictor, Input, Path
from diffusers import (
    AutoencoderKL,
    LMSDiscreteScheduler,
    UNet2DConditionModel,
)
from PIL import Image
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.output_dir = Path("cog_output")  # TODO
        self.output_dir.mkdir(exist_ok=True)
        cache_dir = "model_cache"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # The autoencoder compresses images into a more compact representation,
        # and then decodes them back into pixels.
        self.vae = AutoencoderKL.from_pretrained(
            cache_dir,
            subfolder="vae",
        )
        self.vae.to(self.device)
        print("Loaded autoencoder")

        # Words are first tokenized into integer IDs that CLIP understands.
        # These tokens are then passed to the text encoder, which produces one
        # 768-dimensional embedding per token (a sequence, not a single vector);
        # the UNet cross-attends to this sequence.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            cache_dir,
            subfolder="tokenizer",
        )  # tokenizer is small enough for CPU
        self.text_encoder = CLIPTextModel.from_pretrained(
            cache_dir,
            subfolder="text_encoder",
        )
        self.text_encoder.to(self.device)
        print("Loaded CLIP text encoder.")

        # To generate latents from the text, a denoising diffusion UNet model is used.
        # This model is trained to generate autoencoder latents from text.
        # The autoencoder can then be used to decode the latents into image space.
        self.unet = UNet2DConditionModel.from_pretrained(
            cache_dir,
            subfolder="unet",
            revision="fp16",
            torch_dtype=torch.float16,
        )
        self.unet.to(self.device)
        print("Loaded unet")

        self.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
        )
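        # These betas reproduce the scaled-linear noise schedule Stable Diffusion v1
        # was trained with; the K-LMS scheduler integrates the reverse process over
        # the sigmas derived from them.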

    @torch.inference_mode()  # disables autograd bookkeeping for faster inference
    @torch.cuda.amp.autocast()  # automatically casts eligible ops to fp16
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        image_prompt: Path = Input(description="Input image prompt", default=None),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4, 16], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=100
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        height: int = Input(
            description="Height of output images",
            default=512,
            choices=[256, 384, 512, 640, 768, 1024],
        ),
        width: int = Input(
            description="Width of output images",
            default=512,
            choices=[256, 384, 512, 640, 768, 1024],
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
        image_prompt_strength: float = Input(
            description="Strength of image prompt", ge=0, le=1, default=0.5
        ),
    ) -> List[Path]:
        """Run a single prediction on the model"""
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        generator = torch.manual_seed(seed)
        print(f"Using seed: {seed}")

        self.scheduler.set_timesteps(num_inference_steps)

        # Embed the prompt (repeated once per requested output).
        prompt = [prompt] * num_outputs
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Embed an empty prompt for the unconditional half of classifier-free guidance.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * num_outputs,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat(
            [
                uncond_embeddings.to(self.device),
                text_embeddings.to(self.device),
            ]
        )
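        # text_embeddings is now [2 * num_outputs, 77, 768]: unconditional embeddings
        # first, then the prompt embeddings, so a single batched UNet forward pass
        # produces both halves needed for guidance.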

        # Start from Gaussian noise in latent space (1/8 the output resolution).
        latents = torch.randn(
            (num_outputs, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
        if image_prompt is not None:
            init_image = Image.open(image_prompt).convert("RGB")
            init_image = init_image.resize((int(width), int(height)), Image.LANCZOS)
            init_image = (
                TF.to_tensor(init_image).to(self.device).unsqueeze(0).clamp(0, 1)
            )
            latents = (
                self.vae.encode(init_image.to(self.device) * 2 - 1).sample() * 0.18215
            )
            # TODO - need to add noise to the latents for the correct timesteps, then
            # start the scheduler partway through (image_prompt_strength is unused until then).
            # yikes.
        latents = latents.to(self.device)

        print(f"Using {num_inference_steps} inference steps")
        latents = latents * self.scheduler.sigmas[0]
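        # K-LMS samples in a noise-times-sigma parameterization, so the unit-variance
        # Gaussian latents above are scaled by the largest sigma before the first step
        # (equivalent to init_noise_sigma in later diffusers releases).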

        # Generate latents from the noise.
        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
            # Expand the latents so classifier-free guidance needs only one forward pass.
            latent_model_input = torch.cat([latents] * 2)
            sigma = self.scheduler.sigmas[i]
            latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
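            # Dividing by sqrt(sigma^2 + 1) rescales the input to the variance the UNet
            # expects at this noise level; newer diffusers versions expose this as
            # scheduler.scale_model_input().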

            # Predict the noise residual.
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            # Perform guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
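            # Classifier-free guidance: start from the unconditional prediction and move
            # guidance_scale times further in the direction of the text-conditioned one;
            # a scale of 1 disables guidance, larger values follow the prompt more closely.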

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = self.scheduler.step(noise_pred, i, latents)["prev_sample"]

        # Scale and decode the image latents with the VAE.
        latents = 1 / 0.18215 * latents
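        # 0.18215 is the factor applied at encode time to give the latents roughly
        # unit variance, so it is divided back out before decoding.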
        images = self.vae.decode(latents)

        # Map from [-1, 1] to [0, 1], convert to uint8 HWC arrays, and save each image.
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).round().astype("uint8")
        prediction_paths = []
        for idx, image in enumerate(images):
            output_path = self.output_dir / f"{idx:03d}.png"
            Image.fromarray(image).save(output_path)
            prediction_paths.append(output_path)
        return prediction_paths
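

# Cog normally constructs Predictor itself (via `cog predict`), but a quick local
# smoke test can call it directly. This is only a sketch: it assumes the weights are
# already in ./model_cache and that a CUDA GPU is available, and it passes every
# argument explicitly because the Input(...) defaults above are only resolved when
# Cog invokes predict(). The prompt is an arbitrary example.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    paths = predictor.predict(
        prompt="a photograph of an astronaut riding a horse",
        image_prompt=None,
        num_outputs=1,
        num_inference_steps=50,
        guidance_scale=7.5,
        height=512,
        width=512,
        seed=None,
        image_prompt_strength=0.5,
    )
    print([str(p) for p in paths])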