import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
from tqdm.auto import tqdm

repo_id = "/Volumes/Acasis1TB/machine_learning/stable-diffusion-xl-base-1.0"
torch_device = "mps"
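# "mps" targets Apple-silicon GPUs; a portable variant would pick the device at
# runtime: "cuda" if torch.cuda.is_available() else
# ("mps" if torch.backends.mps.is_available() else "cpu")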

# load the SDXL components individually (SDXL has two tokenizer / text-encoder pairs)
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", use_safetensors=True)
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(
    repo_id, subfolder="text_encoder", use_safetensors=True
)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
    repo_id, subfolder="text_encoder_2", use_safetensors=True
)
model = UNet2DConditionModel.from_pretrained(
    repo_id, subfolder="unet", use_safetensors=True
)

vae.to(torch_device)
text_encoder.to(torch_device)
text_encoder_2.to(torch_device)
model.to(torch_device)

print(model.config)
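# These are the same pieces StableDiffusionXLPipeline bundles; loading them
# individually keeps every stage of the sampling loop explicit.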

prompt = ["a photograph of an astronaut riding a horse"]
height = 1024  # default resolution of SDXL base
width = 1024
num_inference_steps = 15  # number of denoising steps
guidance_scale = 7.5  # scale for classifier-free guidance
generator = torch.Generator(device=torch_device).manual_seed(0)  # reproducible initial noise
batch_size = len(prompt)

# SDXL conditions the UNet on both text encoders: the penultimate hidden state of each
# is concatenated on the feature axis; the pooled embedding comes from the second one.
def encode(prompts):
    embeds, pooled = [], None
    for tok, enc in ((tokenizer, text_encoder), (tokenizer_2, text_encoder_2)):
        ids = tok(
            prompts,
            padding="max_length",
            max_length=tok.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(torch_device)
        with torch.no_grad():
            out = enc(ids, output_hidden_states=True)
        embeds.append(out.hidden_states[-2])
        pooled = out[0]  # kept from the last iteration, i.e. the projection encoder
    return torch.cat(embeds, dim=-1), pooled

text_embeddings, pooled_embeddings = encode(prompt)
uncond_embeddings, uncond_pooled = encode([""] * batch_size)

# unconditional first, conditional second, for classifier-free guidance
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
pooled_embeddings = torch.cat([uncond_pooled, pooled_embeddings])

# SDXL micro-conditioning ids: (original_size, crop_top_left, target_size)
add_time_ids = torch.tensor(
    [[height, width, 0, 0, height, width]], device=torch_device, dtype=text_embeddings.dtype
).repeat(2 * batch_size, 1)
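# shapes: text_embeddings (2 * batch, 77, 2048), pooled_embeddings (2 * batch, 1280)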

latents = torch.randn(
    (batch_size, model.config.in_channels, height // 8, width // 8),
    generator=generator,
    device=torch_device,
)
latents = latents * scheduler.init_noise_sigma
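# init_noise_sigma scales the unit Gaussian up to the scheduler's starting noise
# level, which is the magnitude Euler's first step expects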

scheduler.set_timesteps(num_inference_steps)
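# scheduler.timesteps is now a descending tensor of num_inference_steps noise levels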
for t in tqdm(scheduler.timesteps):
    # expand the latents so the conditional and unconditional branches share one forward pass
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

    # predict the noise residual; the SDXL UNet also requires the pooled embedding and time ids
    with torch.no_grad():
        noise_pred = model(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs={"text_embeds": pooled_embeddings, "time_ids": add_time_ids},
        ).sample

    # perform classifier-free guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (
        noise_pred_text - noise_pred_uncond
    )

    # compute the previous noisy sample x_t -> x_{t-1}
    latents = scheduler.step(noise_pred, t, latents).prev_sample

# undo the latent scaling, then decode with the VAE
latents = latents / vae.config.scaling_factor  # 0.13025 for the SDXL VAE, not SD 1.x's 0.18215
with torch.no_grad():
    image = vae.decode(latents).sample

# map from [-1, 1] to a uint8 image and save (move off MPS before converting to numpy)
image = (image / 2 + 0.5).clamp(0, 1).squeeze()
image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
image = Image.fromarray(image)
image.save("generated_image.png")
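# For reference, the high-level API wraps everything above in one call; a minimal
# sketch (assuming the same local checkpoint), not a tested drop-in:
#   from diffusers import StableDiffusionXLPipeline
#   pipe = StableDiffusionXLPipeline.from_pretrained(repo_id, use_safetensors=True).to(torch_device)
#   image = pipe(prompt[0], num_inference_steps=num_inference_steps,
#                guidance_scale=guidance_scale).images[0]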

# UNet forward signature, per the diffusers docs:
#   sample (torch.FloatTensor): noisy input tensor of shape (batch, channel, height, width)
#   timestep (torch.FloatTensor, float, or int): the current denoising timestep
#   encoder_hidden_states (torch.FloatTensor): shape (batch, sequence_length, feature_dim)
# mlmodel = ct.convert(model)
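# Note: coremltools' ct.convert expects a traced/scripted TorchScript module rather
# than a raw nn.Module; Apple's ml-stable-diffusion repo provides a working
# torch-to-Core ML conversion path for these checkpoints.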