Skip to content

Instantly share code, notes, and snippets.

@recoilme
Created December 9, 2024 12:23
Show Gist options
  • Save recoilme/46dfe328ef19eb5484f489ef4f18c92b to your computer and use it in GitHub Desktop.
Save recoilme/46dfe328ef19eb5484f489ef4f18c92b to your computer and use it in GitHub Desktop.
AuraDiffusionVae.py
from diffusers import AutoencoderTiny, AutoencoderKL
import torch
import torchvision.transforms.functional as TF
from PIL import Image
#taesd3 = AutoencoderTiny.from_pretrained(
# "madebyollin/taesd3", torch_dtype=torch.float16
#).to("cuda")
# Load the VAE checkpoint in float32, then place it on the CPU.
# Alternative checkpoint noted by the author (kept here for reference):
#   "stabilityai/stable-diffusion-3-medium-diffusers" with subfolder="vae"
sd3_vae = AutoencoderKL.from_pretrained(
    "AuraDiffusion/16ch-vae", torch_dtype=torch.float32
)
sd3_vae = sd3_vae.to("cpu")
def _show(picture):
    """Preview ``picture`` through the notebook ``display`` helper if present.

    ``display`` is only defined when running under IPython/Jupyter.  In a
    plain ``python`` run the name does not exist, so the preview is skipped
    instead of aborting the demo with a NameError.
    """
    try:
        notebook_display = display
    except NameError:
        return
    notebook_display(picture)


def demo_sd3_encode_and_taesd_decode(image: Image.Image):
    """Round-trip ``image`` through ``sd3_vae`` and save the reconstruction.

    Despite the historical name, both encoding and decoding use the
    module-level ``sd3_vae``; no TAESD model is involved.

    Steps: convert the image to a tensor, encode it and draw a sample from
    the latent distribution, normalise the latents with the VAE's shift and
    scaling factors, preview the input and the first three latent channels,
    then undo the normalisation, decode, and write the result to ``g3.png``.

    Args:
        image: PIL image to encode.

    Returns:
        None.  The decoded image is written to ``g3.png`` in the current
        working directory (and previewed when ``display`` is available).

    Note:
        The latents are sampled, so repeated calls may give slightly
        different reconstructions.
    """
    vae_cfg = sd3_vae.config
    scale = vae_cfg.scaling_factor
    # The shift factor can be unset (None) on an AutoencoderKL config; treat
    # that as "no shift" rather than failing inside ``sub_``.
    shift = getattr(vae_cfg, "shift_factor", None) or 0.0

    # Pure inference: no need to record an autograd graph.
    with torch.no_grad():
        # Raw image as a batched float32 tensor with values in the [0, 1] range.
        image_raw = TF.to_tensor(image).unsqueeze(0).to(torch.float32).to("cpu")

        # NOTE(review): the pixels are fed to the VAE in [0, 1] and the decoded
        # output is clamped to [0, 1] below; the usual [-1, 1] rescale
        # (``image_raw.mul(2).sub(1)``) was disabled by the author.  Confirm
        # which input range this checkpoint expects.
        image_enc = (
            sd3_vae.encode(image_raw)
            .latent_dist.sample()
            .sub_(shift)
            .mul_(scale)
        )

        _show(TF.to_pil_image(image_raw[0]))
        # Rough visualisation of the first three latent channels.
        _show(TF.to_pil_image(image_enc[0, :3].mul(0.3).add(0.5).clamp(0, 1)))

        # Invert the normalisation in reverse order (un-scale, then un-shift)
        # so the decoder sees latents in the same space the encoder produced.
        image_dec_sd = sd3_vae.decode(image_enc / scale + shift)[0]

    g3 = TF.to_pil_image(image_dec_sd[0].clamp(0, 1))
    _show(g3)
    g3.save("g3.png")
# !wget -nc -q https://lafeber.com/pet-birds/wp-content/uploads/2018/06/Scarlet-Macaw-2.jpg -O test_image.jpg
# Driver: read the input picture, force RGB, resize it with size=512, keep a
# copy of the resized input as g0.png, then run the encode/decode demo on it.
test_image = Image.open("g1.png").convert("RGB")
test_image = TF.resize(test_image, 512)
test_image.save("g0.png")

demo_sd3_encode_and_taesd_decode(test_image)

# Completion marker printed once the demo has finished.
print("1")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment