Last active
November 7, 2023 20:16
-
-
Save rockerBOO/91b344497f3670070449b431289ba6c5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
from collections import defaultdict | |
from pathlib import Path | |
import numpy as np | |
import torch | |
import tqdm | |
from PIL import Image | |
from torchvision import transforms | |
import library.model_util as model_util | |
import library.sdxl_train_util as sdxl_train_util | |
# Preprocessing applied to every loaded image before it is VAE-encoded:
# ToTensor turns the uint8 pixel array into a float tensor, and
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5, shifting values so
# that the decode step can rescale them back from [-1, 1] to [0, 255].
IMAGE_TRANSFORMS = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)
def load_image(image_path):
    """Load an image from disk as an RGB ``numpy.uint8`` array.

    Args:
        image_path: Path (``str`` or ``pathlib.Path``) of the image to read.

    Returns:
        A tuple ``(img, info)``. ``img`` is a ``numpy.uint8`` array of the
        image's RGB pixels. ``info`` is the PIL ``Image.info`` metadata dict of
        the (possibly converted) image.
    """
    # Use a context manager so the underlying file handle is closed as soon as
    # the pixels are read; the previous version left it open until GC.
    with Image.open(image_path) as image:
        if image.mode != "RGB":
            # Normalise palette/greyscale/RGBA/... inputs to 3-channel RGB.
            image = image.convert("RGB")
        # Materialise the pixel data while the file is still open.
        img = np.array(image, np.uint8)
        info = image.info
    return img, info
def process_images_group(vae, images_group):
    """Encode a group of equally-shaped image tensors into VAE latents.

    Args:
        vae: VAE object exposing ``device`` and ``encode(batch)``. The
            result of ``encode`` must have a ``latent_dist`` with a ``sample()``
            method.
        images_group: Sequence of image tensors that share one shape, so they
            can be stacked along a new leading batch dimension.

    Returns:
        The latents sampled from the VAE's latent distribution for the batch.
    """
    with torch.no_grad():
        stacked = torch.stack(images_group, dim=0)
        on_device = stacked.to(vae.device)
        distribution = vae.encode(on_device).latent_dist
        return distribution.sample()
# File extensions (lower-case, with the leading dot that ``Path.suffix``
# returns) that are picked up when a directory is given as input.
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".webp", ".bmp", ".avif"})


def _collect_image_files(input_path):
    """Return the image files to process for a single file or a directory."""
    if not input_path.is_dir():
        # A single explicit file is used as-is, whatever its extension.
        return [input_path]
    # ``Path.suffix`` includes the dot (".png"); the previous check compared it
    # against bare names ("png") and therefore never matched any file.
    # Sorting makes the processing order deterministic.
    return sorted(
        file
        for file in input_path.iterdir()
        if file.is_file() and file.suffix.lower() in _IMAGE_SUFFIXES
    )


def _resolve_consistency_flag(consistency_decoder):
    """Decide whether to decode with the consistency decoder."""
    if consistency_decoder is not None:
        return bool(consistency_decoder)
    # Backward compatibility: the script historically read the module-level
    # ``args`` set in ``__main__``. Fall back to it when it exists instead of
    # raising NameError when this module is imported and called directly.
    cli_args = globals().get("args")
    return bool(getattr(cli_args, "consistency_decoder", False))


def _group_by_size(image_files):
    """Load and transform images, grouping ``(tensor, file)`` pairs by shape."""
    size_to_items = defaultdict(list)
    for image_file in image_files:
        image, _ = load_image(image_file)
        tensor = IMAGE_TRANSFORMS(image)
        # Keep each file name paired with its own tensor so names stay aligned
        # with images even when there are several size groups.
        size_to_items[tuple(tensor.shape[1:])].append((tensor, image_file))
    return size_to_items


def process_latents_from_images(
    vae, input_file_or_dir, output_dir, batch_size=1, consistency_decoder=None
):
    """Round-trip images through the VAE (encode, then decode) and save them.

    Images of the same spatial size are batched together (stacking requires
    equal shapes). Each batch is encoded with ``process_images_group`` and then
    decoded either by the VAE (``decode_vae_and_save``) or by OpenAI's
    consistency decoder (``consistencydecoder_and_save``).

    Args:
        vae: VAE used to encode (and, by default, decode) the images.
        input_file_or_dir: A single image file, or a directory whose image
            files (by extension) are all processed.
        output_dir: Directory that receives the decoded images; it is created
            if missing.
        batch_size: Maximum number of same-sized images encoded together.
        consistency_decoder: ``True``/``False`` to choose the decoder
            explicitly. ``None`` (default) falls back to the
            ``args.consistency_decoder`` CLI flag when running as a script.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    use_consistency = _resolve_consistency_flag(consistency_decoder)
    if use_consistency:
        # Optional dependency, only imported when requested.
        from consistencydecoder import ConsistencyDecoder

        decoder_consistency = ConsistencyDecoder(device=vae.device)

    input_path = Path(input_file_or_dir)
    output = Path(output_dir)
    output.mkdir(parents=True, exist_ok=True)

    size_to_items = _group_by_size(_collect_image_files(input_path))
    total_images = sum(len(items) for items in size_to_items.values())

    with tqdm.tqdm(total=total_images) as progress_bar:
        for items in size_to_items.values():
            for start in range(0, len(items), batch_size):
                chunk = items[start : start + batch_size]
                batch = [tensor for tensor, _ in chunk]
                batch_file_names = [name for _, name in chunk]
                latents = process_images_group(vae, batch)
                if use_consistency:
                    consistencydecoder_and_save(
                        decoder_consistency,
                        latents,
                        batch_file_names,
                        output,
                        device=vae.device,
                    )
                else:
                    decode_vae_and_save(vae, latents, batch_file_names, output)
                # The bar's total counts images, so advance by images, not batches.
                progress_bar.update(len(chunk))
def decode_vae_and_save(vae, latents, filenames, output):
    """Decode a batch of latents with the VAE and write the results to disk.

    For each latent and its matching source file ``<stem>.<ext>``, two files
    are written into ``output``:

    * ``<stem>-latents-decoded.png``: the decoded image.
    * ``<stem>-latents-decoded.gif``: a looping two-frame GIF (original image,
      then decoded image, 500 ms per frame) for quick visual comparison.

    Args:
        vae: VAE whose ``decode(latent).sample`` yields images in [-1, 1].
        latents: Batch of latents, indexed along the first dimension.
        filenames: Source image paths, aligned with ``latents``.
        output: Output directory (``str`` or ``pathlib.Path``).
    """
    if len(latents) == 0:
        return

    with torch.no_grad():
        # Decode one latent at a time (presumably to bound peak memory). The
        # previous loop was ``range(0, 1, 1)`` and only ever decoded the first
        # latent, silently dropping the rest of the batch when batch_size > 1.
        decoded = torch.cat(
            [vae.decode(latents[i : i + 1]).sample for i in range(len(latents))]
        )

    # Rescale images from [-1, 1] to [0, 255] and convert NCHW -> NHWC uint8.
    decoded_images = (
        ((decoded / 2 + 0.5).clamp(0, 1) * 255)
        .cpu()
        .permute(0, 2, 3, 1)
        .numpy()
        .astype("uint8")
    )

    output_dir = Path(output).absolute()
    for original_file, decoded_image in zip(filenames, decoded_images):
        original_file = Path(original_file)
        stem = original_file.stem

        output_image = Image.fromarray(decoded_image)
        output_image.save(output_dir / f"{stem}-latents-decoded.png")

        # Close the original image file once the comparison GIF is written.
        with Image.open(original_file) as original_image:
            original_image.save(
                output_dir / f"{stem}-latents-decoded.gif",
                save_all=True,
                append_images=[output_image],
                duration=500,
                loop=0,
            )
def consistencydecoder_and_save(
    decoder_consistency, latents, filenames, output_dir, device
):
    """Decode latents with OpenAI's consistency decoder and save the images.

    Each decoded image is written as
    ``<output_dir>/<stem>-latents-decoded-consistency.png``, where ``<stem>`` is
    the source file name without directory or extension.

    Args:
        decoder_consistency: Callable ``ConsistencyDecoder`` instance that maps
            a batch of latents to a batch of images.
        latents: Batch of latents to decode.
        filenames: Source image paths, aligned with ``latents``.
        output_dir: Directory (``str`` or ``pathlib.Path``) for the outputs.
        device: Unused; kept for interface compatibility with callers.
    """
    # Optional dependency, only imported when this decoder is used.
    from consistencydecoder import save_image

    with torch.no_grad():
        sample_consistences = decoder_consistency(latents)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for original_file, decoded_image in zip(filenames, sample_consistences):
        # Use only the stem. Previously the full input path (minus extension)
        # was joined onto output_dir, so absolute inputs discarded output_dir
        # and relative ones were nested under it in non-existent subfolders.
        stem = Path(original_file).stem
        save_image(
            decoded_image,
            str(output_dir / f"{stem}-latents-decoded-consistency.png"),
        )
def main(args):
    """Load a VAE according to the CLI arguments and round-trip the images.

    The VAE is taken from ``--vae`` when given. Otherwise it is extracted from
    ``--pretrained_model_name_or_path``, using the SDXL loader when ``--sdxl``
    is set and the SD 1.x/2.x checkpoint loader otherwise.

    Args:
        args: Parsed ``argparse.Namespace`` from the ``__main__`` parser.
    """
    device = torch.device(args.device)

    # ``--vae`` historically defaulted to "" rather than None, so the old
    # ``args.vae is None`` test never held and ``load_vae("")`` was attempted.
    # Treat both None and the empty string as "no separate VAE given".
    if args.vae:
        vae = model_util.load_vae(args.vae, torch.float32)
    elif args.sdxl:
        # Accelerator is only needed to satisfy the SDXL loader's signature.
        from accelerate import Accelerator

        accelerator = Accelerator()
        _, _, _, vae, _, _, _ = sdxl_train_util.load_target_model(
            args,
            accelerator,
            args.pretrained_model_name_or_path,
            torch.float16,
        )
    else:
        # Load model's VAE from an SD 1.x / 2.x checkpoint.
        _, vae, _ = model_util.load_models_from_stable_diffusion_checkpoint(
            args.v2,
            args.pretrained_model_name_or_path,
        )

    # Whichever branch loaded it, run the VAE on the requested device in
    # float32, matching the float32 tensors produced by IMAGE_TRANSFORMS.
    vae = vae.to(device, dtype=torch.float32)

    # Save image decoded latents
    process_latents_from_images(
        vae, args.input_file_or_dir, args.output_dir, args.batch_size
    )
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description=(
            "Encode images with a Stable Diffusion VAE, decode them again and "
            "save the reconstructions for inspection."
        )
    )
    argparser.add_argument("--device", default="cpu")
    # Both paths are needed; without them Path(None) used to fail later with a
    # confusing TypeError instead of a clear usage error.
    argparser.add_argument(
        "--input_file_or_dir",
        required=True,
        help="Input file or directory to load the images from",
    )
    argparser.add_argument(
        "--output_dir",
        required=True,
        help="Output directory to put the VAE decoded images",
    )
    # Default None so "no --vae given" is distinguishable; main() falls back to
    # the VAE inside --pretrained_model_name_or_path in that case.
    argparser.add_argument(
        "--vae", default=None, help="Path to VAE file or hugging face VAE path"
    )
    argparser.add_argument(
        "--pretrained_model_name_or_path",
        default="",
        help="Stable diffusion model name or path to load the VAE from.",
    )
    argparser.add_argument(
        "--v2", action="store_true", help="Is a Stable Diffusion v2 model."
    )
    argparser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size to process the images."
    )
    argparser.add_argument(
        "--sdxl", action="store_true", help="(NOTWORKING) SDXL model"
    )
    # NOTE(review): the three options below are not read in this file. They are
    # presumably consumed by sdxl_train_util.load_target_model, which receives
    # ``args``. All default to 1 (truthy); having both full_fp16 and full_bf16
    # enabled by default looks contradictory. Confirm against that loader.
    argparser.add_argument("--lowram", type=int, default=1, help="SDXL low ram option")
    argparser.add_argument(
        "--full_fp16", type=int, default=1, help="SDXL use full fp16"
    )
    argparser.add_argument(
        "--full_bf16", type=int, default=1, help="SDXL use full bf16"
    )
    argparser.add_argument(
        "--consistency_decoder",
        action="store_true",
        help="Use Consistency Decoder from OpenAI https://github.com/openai/consistencydecoder",
    )
    # Kept as a module-level name: process_latents_from_images reads
    # ``args.consistency_decoder`` from it.
    args = argparser.parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment