Encapsulate an encoded VAE latent as PNG image, then load it and use it to decode the original image
#!pip install diffusers==0.2.4

import torch
from diffusers import AutoencoderKL
from PIL import Image
import numpy as np
from torchvision import transforms as tfms

torch_device = None
vae = None
to_tensor_tfm = None

def setup():
    global torch_device
    global vae
    global to_tensor_tfm

    # Set device
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the autoencoder model, used to encode images into latent space and decode latents back into image space
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True)

    # To the GPU we go!
    vae = vae.to(torch_device)

    # Using torchvision.transforms.ToTensor
    to_tensor_tfm = tfms.ToTensor()
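
# Note: this script assumes the diffusers==0.2.4 API pinned above, where vae.encode()
# returns the posterior distribution directly and vae.decode() returns an image tensor.
# Newer diffusers releases wrap these results in output objects (e.g. .latent_dist and
# .sample), so the calls below would need small adjustments there.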

def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so shape 1, 4, 64, 64)
    with torch.no_grad():
        # ToTensor gives [0, 1]; the VAE expects inputs in [-1, 1], hence the *2-1 scaling
        latent = vae.encode(to_tensor_tfm(input_im).unsqueeze(0).to(torch_device) * 2 - 1)
    return 0.18215 * latent.sample()  # or latent.mean for a deterministic encoding
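
# The 0.18215 factor is the latent scaling constant used by Stable Diffusion v1 to give
# the VAE latents roughly unit variance; latents_to_pil() divides it back out before decoding.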

def latents_to_pil(latents):
    # Batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents)
    # Map from [-1, 1] back to [0, 1], then to uint8 HWC arrays for PIL
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

def latents_as_images(latents):
    # Save the raw latents, then a min-max normalised copy, as .npy files
    np.save("homer_latent", latents.detach().cpu().permute(0, 2, 3, 1).numpy(), allow_pickle=False)
    minValue = latents.min()
    maxValue = latents.max()
    latents = (latents - minValue) / (maxValue - minValue)
    np.save("homer_latent_norm", latents.detach().cpu().permute(0, 2, 3, 1).numpy(), allow_pickle=False)
    # Normalised latents -> uint8 images (the 4 latent channels become RGBA)
    image = latents.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
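
# Note: the PNG round trip is lossy. The latents are min-max normalised to [0, 1] and
# quantised to 8 bits per channel, and the per-image min/max are not stored in the PNG
# itself, so they have to be supplied again when loading (see reduced_latents_from_png below).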

def encode(input_image_file):
    # Load the image with PIL, force 3 channels, and resize to the 512x512 the VAE expects here
    input_image = Image.open(input_image_file).convert("RGB").resize((512, 512))
    encoded = pil_to_latent(input_image)
    return encoded

def decode(encoded_latents):
    decoded = latents_to_pil(encoded_latents)[0]
    return decoded

def encode_save_png():
    print("Encode image")
    encoded_latents = encode('homer.jpg')
    print("encoded_latents.shape = " + str(encoded_latents.shape))  # encoded_latents.shape = torch.Size([1, 4, 64, 64])
    encoded_latents_as_image = latents_as_images(encoded_latents)
    encoded_latents_as_image[0].save('encoded_latents_as_image_homer.png')  # Note: the alpha channel also contains information
    encoded_latents_as_image[0].show()
    print("Decode Image")
    image = decode(encoded_latents)
    image.save('decoded_image_from_latents_homer.png')
    image.show()

def reduced_latents_from_png(png_image_name):
    # Load the RGBA PNG back into a (1, 4, 64, 64) float tensor in [0, 1]
    image_in = Image.open(png_image_name)
    image_in = np.array(image_in, np.float32)
    print(image_in.shape)
    image_in = image_in / 255.0
    image_in = image_in.transpose((2, 0, 1))
    image_out = np.expand_dims(image_in, 0)
    reduced_latents = torch.tensor(image_out).to(torch_device)  # move to the same device as the VAE
    # Undo the min-max normalisation. These constants must match the min/max of the original
    # latents; here they are hard-coded for the homer.jpg example
    kMax = 3.9860
    kMin = -2.7767
    reduced_latents = (reduced_latents * (kMax - kMin)) + kMin
    print(reduced_latents.shape)
    return reduced_latents
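
# A possible refinement (a sketch, not part of the original flow above): instead of
# hard-coding kMin/kMax, the per-image min/max could be stored in the PNG's text metadata
# and read back when loading. The two helper names below are illustrative only.
def save_latents_png_with_range(pil_image, min_value, max_value, png_image_name):
    from PIL.PngImagePlugin import PngInfo
    meta = PngInfo()
    meta.add_text("latent_min", str(float(min_value)))
    meta.add_text("latent_max", str(float(max_value)))
    pil_image.save(png_image_name, pnginfo=meta)

def load_latents_png_with_range(png_image_name):
    image_in = Image.open(png_image_name)
    k_min = float(image_in.text["latent_min"])
    k_max = float(image_in.text["latent_max"])
    arr = np.array(image_in, np.float32) / 255.0
    latents = torch.tensor(np.expand_dims(arr.transpose((2, 0, 1)), 0)).to(torch_device)
    return (latents * (k_max - k_min)) + k_min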

def load_png_decode():
    print("Load reduced latents")
    reduced_latents = reduced_latents_from_png('encoded_latents_as_image_homer.png')
    print("Decode reduced latents")
    image = decode(reduced_latents)
    image.save('decoded_image_from_reduced_latents_homer.png')
    image.show()
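
# Optional check (an addition, not in the original gist): quantify how much the 8-bit PNG
# round trip changed the decoded result by comparing the two outputs saved above. Could be
# called after load_png_decode() in main().
def compare_decoded_images():
    a = np.array(Image.open('decoded_image_from_latents_homer.png'), np.float32)
    b = np.array(Image.open('decoded_image_from_reduced_latents_homer.png'), np.float32)
    print("mean absolute pixel difference:", np.abs(a - b).mean())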

def main():
    #!curl --output homer.jpg 'https://static.wikia.nocookie.net/peppapedia/images/f/f0/Homer_Simpson_concept_art.jpg/revision/latest?cb=20191102051909'
    print("Load VAE")
    setup()
    encode_save_png()
    load_png_decode()

if __name__ == '__main__':
    main()