import torch
from PIL import Image
import numpy as np
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler

# from diffusers import UNet2DModel
repo_id = "/Volumes/Acasis1TB/machine_learning/stable-diffusion-xl-base-1.0"
torch_device = "mps"

scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", use_safetensors=True)

# SDXL conditions the UNet on two text encoders: the base CLIP text model and a
# second, larger CLIP model with a projection head that also supplies the pooled embedding.
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(
    repo_id, subfolder="text_encoder", use_safetensors=True
)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
    repo_id, subfolder="text_encoder_2", use_safetensors=True
)
model = UNet2DConditionModel.from_pretrained(
    repo_id, subfolder="unet", use_safetensors=True
)

vae.to(torch_device)
text_encoder.to(torch_device)
text_encoder_2.to(torch_device)
model.to(torch_device)
# print(model.config)
# attributes = dir(model)
# print(attributes)
# print(dir(model.forward))
# help(model.forward)
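# Optional sanity check: the script below assumes the stock SDXL-base configuration;
# these prints show the config values it relies on.
# print("unet in_channels:", model.config.in_channels)                  # expected: 4
# print("unet cross_attention_dim:", model.config.cross_attention_dim)  # expected: 2048
# print("vae scaling_factor:", vae.config.scaling_factor)               # expected: 0.13025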
prompt = ["a photograph of an astronaut riding a horse"]
height = 1024  # default height for SDXL base
width = 1024  # default width for SDXL base
num_inference_steps = 15  # number of denoising steps
guidance_scale = 7.5  # scale for classifier-free guidance
generator = torch.manual_seed(0)  # CPU generator; the initial latent noise is sampled on CPU and moved to the device below
batch_size = len(prompt)
# for name, param in model.named_parameters():
#     if "some_keyword" in name:  # Adjust the keyword
#         print(name, param)


def encode_prompt(text):
    """Encode a batch of prompts with both SDXL text encoders.

    Returns the per-token embeddings concatenated along the feature axis
    (what the UNet receives as encoder_hidden_states) and the pooled
    embedding from the second encoder (needed in added_cond_kwargs).
    """
    token_embeddings = []
    pooled_embedding = None
    with torch.no_grad():
        for tok, enc in ((tokenizer, text_encoder), (tokenizer_2, text_encoder_2)):
            tokens = tok(
                text,
                padding="max_length",
                max_length=tok.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            output = enc(tokens.input_ids.to(torch_device), output_hidden_states=True)
            # Only the second encoder's projected pooled output is kept,
            # mirroring what StableDiffusionXLPipeline does.
            pooled_embedding = output[0]
            # Penultimate hidden state, as used by the SDXL pipeline.
            token_embeddings.append(output.hidden_states[-2])
    return torch.cat(token_embeddings, dim=-1), pooled_embedding


text_embeddings, pooled_embeddings = encode_prompt(prompt)
uncond_embeddings, uncond_pooled = encode_prompt([""] * batch_size)

# For classifier-free guidance the unconditional batch goes first, the conditional batch second.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
pooled_embeddings = torch.cat([uncond_pooled, pooled_embeddings])
# print(text_embeddings.shape)  # (2 * batch_size, 77, 2048)
latents = torch.randn(
    (batch_size, model.config.in_channels, height // 8, width // 8),
    generator=generator,
)
# The Euler scheduler expects the starting noise scaled by its initial sigma.
latents = latents.to(torch_device) * scheduler.init_noise_sigma
from tqdm.auto import tqdm

scheduler.set_timesteps(num_inference_steps)

# SDXL micro-conditioning: (original_size, crop_top_left, target_size) per sample,
# duplicated because the batch is doubled for classifier-free guidance.
add_time_ids = torch.tensor(
    [[height, width, 0, 0, height, width]] * (batch_size * 2),
    device=torch_device,
    dtype=text_embeddings.dtype,
)
added_cond_kwargs = {"text_embeds": pooled_embeddings, "time_ids": add_time_ids}

for t in tqdm(scheduler.timesteps):
    # expand the latents so one forward pass covers both the unconditional and the conditional prompt
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

    # predict the noise residual
    with torch.no_grad():
        noise_pred = model(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

    # perform classifier-free guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (
        noise_pred_text - noise_pred_uncond
    )

    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(noise_pred, t, latents).prev_sample
# scale and decode the image latents with the vae (SDXL uses 0.13025, so read it from the config)
latents = latents / vae.config.scaling_factor
with torch.no_grad():
    image = vae.decode(latents).sample

# map from [-1, 1] to [0, 255] uint8 and save
image = (image / 2 + 0.5).clamp(0, 1).squeeze()
image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
image = Image.fromarray(image)
image.save("generated_image.png")
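# The manual conversion above mirrors what diffusers' VaeImageProcessor.postprocess does;
# if the installed diffusers version provides it, the steps after vae.decode could instead be:
# from diffusers.image_processor import VaeImageProcessor
# image = VaeImageProcessor(vae_scale_factor=8).postprocess(image, output_type="pil")[0]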
# UNet2DConditionModel.forward reference (from the diffusers docs):
#   sample (torch.FloatTensor): the noisy input tensor, shape (batch, channel, height, width)
#   timestep (torch.FloatTensor or float or int): the current denoising timestep
#   encoder_hidden_states (torch.FloatTensor): encoder hidden states, shape (batch, sequence_length, feature_dim)
# noise_pred = model(sample, timestep, encoder_hidden_states).sample
# (calls like `.images[0]` or `num_inference_steps=20` belong to a full pipeline, not the raw UNet)
# mlmodel = ct.convert(model)  # would also need `import coremltools as ct`
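# A minimal shape check built only on the forward signature documented above;
# `_check_unet_shapes` and its dummy tensors are illustrative, not part of the pipeline run.
def _check_unet_shapes():
    sample = torch.randn(1, model.config.in_channels, height // 8, width // 8, device=torch_device)
    hidden_states = torch.randn(1, 77, 2048, device=torch_device)  # concatenated SDXL text embeddings
    cond = {
        "text_embeds": torch.randn(1, 1280, device=torch_device),  # pooled embedding from text_encoder_2
        "time_ids": torch.tensor(
            [[height, width, 0, 0, height, width]], device=torch_device, dtype=torch.float
        ),
    }
    with torch.no_grad():
        out = model(sample, 10, encoder_hidden_states=hidden_states, added_cond_kwargs=cond).sample
    print("noise prediction shape:", out.shape)  # expected to match sample.shape


# _check_unet_shapes()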