Skip to content

Instantly share code, notes, and snippets.

@Norod
Last active September 28, 2022 08:51
Show Gist options
  • Save Norod/83911b25ccf0cde368516cf5bbc131cf to your computer and use it in GitHub Desktop.
Save Norod/83911b25ccf0cde368516cf5bbc131cf to your computer and use it in GitHub Desktop.
Encapsulate an encoded VAE latent as a PNG image, then load it and use it to decode the original image
#!pip install diffusers==0.2.4
import torch
from diffusers import AutoencoderKL
from PIL import Image
import numpy as np
from torchvision import transforms as tfms
# Module-level state shared by the helpers below. All three names stay None
# until setup() assigns the device string, the VAE and the ToTensor transform.
torch_device = vae = to_tensor_tfm = None
def setup():
    """Initialise the module-level device, VAE and tensor transform.

    Selects "cuda" when torch reports it as available and "cpu" otherwise,
    loads the autoencoder from the "vae" subfolder of the
    CompVis/stable-diffusion-v1-4 checkpoint onto that device, and creates
    the torchvision ``ToTensor`` transform. The encode/decode helpers read
    these globals, so this must run before them.
    """
    global torch_device, vae, to_tensor_tfm

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    # The autoencoder is what turns images into latents and back again.
    autoencoder = AutoencoderKL.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        subfolder="vae",
        use_auth_token=True,
    )
    vae = autoencoder.to(torch_device)

    to_tensor_tfm = tfms.ToTensor()
def pil_to_latent(input_im):
    """Encode a single PIL image into a scaled latent batch of size one.

    Per the original author's note, a 512x512 input yields a latent of
    shape (1, 4, 64, 64).
    """
    with torch.no_grad():
        pixels = to_tensor_tfm(input_im).unsqueeze(0).to(torch_device)
        # Rescale the tensor with *2-1 before handing it to the encoder
        # (presumably mapping ToTensor's [0, 1] output to [-1, 1]).
        posterior = vae.encode(pixels * 2 - 1)
    # Draw a sample from the encoder output and apply the 0.18215 scaling
    # that latents_to_pil undoes.
    return 0.18215 * posterior.sample()
def latents_to_pil(latents):
    """Decode a batch of scaled latents into a list of PIL images.

    Args:
        latents: Tensor of latents carrying the 0.18215 scaling applied by
            ``pil_to_latent``. It may live on any device; it is moved to
            ``torch_device`` before decoding.

    Returns:
        A list with one ``PIL.Image`` per entry in the batch.
    """
    # Latents rebuilt from a PNG are plain CPU tensors. Move them next to the
    # VAE so decoding also works when the model was placed on the GPU.
    latents = latents.to(torch_device)
    # Undo the 0.18215 scaling applied at encode time.
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents)
    # Shift/scale the decoder output by /2 + 0.5 and clamp it into [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    # Channels-first tensor -> channels-last numpy array for PIL.
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in images]
def latents_as_images(latents):
    """Render a batch of latents as 8-bit PIL images.

    The latents are min-max normalised over the whole batch into [0, 1],
    scaled to 0..255 and converted channels-last so each batch entry becomes
    one image. As a side effect the raw and the normalised latents are
    written to ``homer_latent.npy`` and ``homer_latent_norm.npy``.

    Note that the min/max used for normalisation are not returned, so the
    caller needs them from elsewhere to invert the mapping.
    """
    def channels_last(tensor):
        # (N, C, H, W) tensor -> (N, H, W, C) numpy array on the CPU.
        return tensor.detach().cpu().permute(0, 2, 3, 1).numpy()

    np.save("homer_latent", channels_last(latents), allow_pickle=False)

    lowest = latents.min()
    highest = latents.max()
    normalised = (latents - lowest) / (highest - lowest)
    normalised_array = channels_last(normalised)
    np.save("homer_latent_norm", normalised_array, allow_pickle=False)

    pixels = (normalised_array * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in pixels]
def encode(input_image_file):
    """Load an image file and encode it into VAE latents.

    Args:
        input_image_file: Path (or file object) accepted by ``Image.open``.

    Returns:
        The latent tensor produced by ``pil_to_latent`` for the image
        resized to 512x512.
    """
    with Image.open(input_image_file) as source:
        # Force three channels: RGBA, palette, grayscale or CMYK files would
        # otherwise reach the VAE with the wrong channel count.
        input_image = source.convert("RGB").resize((512, 512))
    return pil_to_latent(input_image)
def decode(encoded_latents):
    """Decode latents and return the image for the first batch entry."""
    return latents_to_pil(encoded_latents)[0]
def enocde_save_png():
    """Encode homer.jpg, save its latents as a PNG, then decode and save it.

    Writes ``encoded_latents_as_image_homer.png`` (the latents rendered as an
    image) and ``decoded_image_from_latents_homer.png`` (the VAE
    reconstruction), and opens both in the default viewer.
    """
    print("Encode image")
    latents = encode('homer.jpg')
    # Original author's note: encoded_latents.shape = torch.Size([1, 4, 64, 64])
    print("encoded_latents.shape = " + str(latents.shape))

    # Note: the alpha channel of this PNG also contains information.
    latent_png = latents_as_images(latents)[0]
    latent_png.save('encoded_latents_as_image_homer.png')
    latent_png.show()

    print("Decode Image")
    reconstruction = decode(latents)
    reconstruction.save('decoded_image_from_latents_homer.png')
    reconstruction.show()
def reduced_latents_from_png(png_image_name, latent_min=-2.7767, latent_max=3.9860):
    """Rebuild approximate latents from a PNG written via ``latents_as_images``.

    The 8-bit pixel values are mapped back to [0, 1] and then stretched to
    ``[latent_min, latent_max]``, inverting the min-max normalisation applied
    when the PNG was produced. Precision is limited to the 256 levels the PNG
    stores per channel.

    Args:
        png_image_name: Path (or file object) of the PNG to read. It must
            decode to an array with a channel axis (H, W, C).
        latent_min: Minimum of the original latents. The default is the value
            previously hard-coded in this function.
        latent_max: Maximum of the original latents. The default is the value
            previously hard-coded in this function.

    Returns:
        A float32 CPU tensor of shape (1, C, H, W).
    """
    with Image.open(png_image_name) as png:
        image_in = np.array(png, np.float32)
    print(image_in.shape)
    image_in = image_in / 255.0
    # Channels-last (H, W, C) -> channels-first (C, H, W), then add a batch axis.
    image_in = image_in.transpose((2, 0, 1))
    image_out = np.expand_dims(image_in, 0)
    reduced_latents = torch.tensor(image_out)
    reduced_latents = (reduced_latents * (latent_max - latent_min)) + latent_min
    print(reduced_latents.shape)
    return reduced_latents
def load_png_decode():
    """Decode the image back from the latents stored in the saved PNG.

    Reads ``encoded_latents_as_image_homer.png``, decodes the recovered
    latents, saves the result as
    ``decoded_image_from_reduced_latents_homer.png`` and displays it.
    """
    print("Load reduced latents")
    latents = reduced_latents_from_png('encoded_latents_as_image_homer.png')

    print("Decode reduced latents")
    reconstruction = decode(latents)
    reconstruction.save('decoded_image_from_reduced_latents_homer.png')
    reconstruction.show()
def main():
    """Run the demo: load the VAE, encode to PNG, then decode from the PNG."""
    # homer.jpg must be present in the working directory, e.g. fetched with:
    #!curl --output homer.jpg 'https://static.wikia.nocookie.net/peppapedia/images/f/f0/Homer_Simpson_concept_art.jpg/revision/latest?cb=20191102051909'
    print("Load VAE")
    for step in (setup, enocde_save_png, load_png_decode):
        step()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment