# Compress images with the SDXL auto-encoder.
#
# Originally published as GitHub gist btlorch/5ddea5e6a0951995d09536fbc95a3dfd
# (created December 17, 2023 13:02).
import argparse
import os
from glob import glob

import numpy as np
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, DiffusionPipeline
from PIL import Image
from tqdm import tqdm
# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def prepare_image(filepath):
    """Load an image file as a batched tensor suitable for the auto-encoder.

    Args:
        filepath: Path to an image file readable by PIL.

    Returns:
        A float tensor of shape (1, 3, H, W) on DEVICE with values in [-1, 1].
    """
    # Open inside a context manager so the file handle is closed
    # deterministically: Image.open is lazy and keeps the file open until
    # the image is closed.
    with Image.open(filepath) as img:
        # The auto-encoder expects exactly three channels; grayscale, palette
        # or RGBA inputs would otherwise become 1- or 4-channel tensors.
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Convert to tensor of shape (3, H, W) in range [0, 1]
        img_torch = T.ToTensor()(img)
    # Add singleton batch dimension and move to the target device
    img_batch = torch.unsqueeze(img_torch, dim=0).to(DEVICE)
    # Normalize to range [-1, 1]
    img_batch = img_batch * 2. - 1.
    return img_batch
def tensor_to_pil(x):
    """Convert a single channel-first image tensor in [-1, 1] to an RGB PIL image.

    Args:
        x: Tensor of shape (C, H, W). Values outside [-1, 1] are clipped.
            Presumably C == 3 (auto-encoder output); single-channel and
            4-channel tensors are converted to RGB as well.

    Returns:
        A PIL image in mode "RGB" of size (W, H).
    """
    x = x.detach().cpu()
    # Clip to range [-1, 1]
    x = torch.clamp(x, -1., 1.)
    # Scale to range [0, 1]
    x = (x + 1.) / 2.
    # Move channel axis to the end (CHW -> HWC), the layout PIL expects
    x = x.permute(1, 2, 0).numpy()
    # Scale to uint8 range. Round to the nearest level instead of truncating:
    # truncation biases every pixel downward by half a level on average, and
    # float error on an exact level (e.g. 254.9999 for 255) would drop it by one.
    x = np.round(255. * x).astype(np.uint8)
    # Image.fromarray does not accept an (H, W, 1) array; drop the singleton
    # channel so single-channel input becomes a grayscale image.
    if x.shape[2] == 1:
        x = x[:, :, 0]
    # Convert to PIL image
    img = Image.fromarray(x)
    # Always hand back an RGB image, whatever the channel count was
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img
def load_vae():
    """Load the SDXL base auto-encoder onto DEVICE in float32.

    Only the ``vae`` sub-folder of the SDXL base checkpoint is loaded. The
    rest of the pipeline (UNet, text encoders) is not needed to encode and
    decode images.

    Returns:
        The ``AutoencoderKL`` module, on DEVICE, in eval mode.
    """
    # Loading the full DiffusionPipeline and moving it to DEVICE would also
    # put the UNet and text encoders in (GPU) memory even though only the
    # auto-encoder is used.
    # NOTE(review): the fp16 weight files are downloaded but upcast to
    # float32, presumably to avoid the SDXL VAE's known float16 overflow —
    # confirm before changing the dtype.
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
        torch_dtype=torch.float32,
        variant="fp16",
        use_safetensors=True,
    )
    vae.to(DEVICE)
    # Inference only
    vae.eval()
    return vae
def compress_decompress(filepaths, output_dir, vae):
    """Round-trip each image through the auto-encoder and save the result as PNG.

    The output is written to ``output_dir/<input stem>.png``. Inputs whose
    output file already exists are skipped, which lets a partially completed
    run be resumed.

    Args:
        filepaths: Iterable of input image paths.
        output_dir: Directory for the reconstructed images; created if missing.
        vae: Auto-encoder module; called as ``vae(batch)["sample"]``.
    """
    # Saving would fail if the output directory does not exist yet
    os.makedirs(output_dir, exist_ok=True)
    for filepath in tqdm(filepaths):
        stem = os.path.splitext(os.path.basename(filepath))[0]
        output_filepath = os.path.join(output_dir, stem + ".png")
        if os.path.exists(output_filepath):
            # tqdm.write keeps the progress bar intact, unlike print
            tqdm.write(f"Skipping because output file \"{output_filepath}\" already exists")
            continue
        # Load image and convert to range [-1, +1]
        img_batch = prepare_image(filepath)
        # Feed through auto-encoder. Gradients are never needed here, so
        # disable autograd to avoid storing the activation graph.
        with torch.no_grad():
            img_reconstructed_batch = vae(img_batch)["sample"]
        # Convert back to PIL image
        img_reconstructed = tensor_to_pil(img_reconstructed_batch[0])
        img_reconstructed.save(output_filepath)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compress images with the SDXL auto-encoder")
    parser.add_argument("--input_dir", type=str, help="Path to input directory", default="ALASKA_v2_TIFF_512_COLOR")
    parser.add_argument("--output_dir", type=str, help="Path to output directory", default="/tmp")
    parser.add_argument("--max_num_images", type=int, help="Take only a limited number of samples")
    args = vars(parser.parse_args())

    max_num_images = args["max_num_images"]
    # A negative slice bound would silently drop files from the end instead
    if max_num_images is not None and max_num_images < 0:
        parser.error("--max_num_images must be non-negative")

    filepaths = sorted(glob(os.path.join(args["input_dir"], "*.tif")))
    # Compare against None so an explicit 0 is honoured rather than being
    # treated as "no limit"
    if max_num_images is not None:
        filepaths = filepaths[:max_num_images]

    # Don't download/load the model when there is nothing to process
    if not filepaths:
        parser.exit(message=f"No *.tif images selected from \"{args['input_dir']}\"; nothing to do.\n")

    vae = load_vae()
    compress_decompress(filepaths, args["output_dir"], vae=vae)
# Discussion of this gist takes place on GitHub (sign-in required to comment).