Created
January 31, 2025 10:46
-
-
Save brayevalerien/4bfc791adc35cc8ac845c5dc2e785042 to your computer and use it in GitHub Desktop.
Automatically sorts and resizes/crops images into resolution buckets for diffusion model training.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
from PIL import Image | |
from PIL.Image import Resampling | |
def get_args():
    """
    Build the command line interface and parse the arguments of the current process.
    Exposes --input (required), --buckets and --output (both optional, with defaults).
    Returns the namespace produced by argparse.
    """
    cli = argparse.ArgumentParser(
        prog="Bucketize",
        description="Automatically storts and resize/crop images into resolution buckets for diffusion models training.",
    )
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    options = (
        (
            "--input",
            {
                "help": "Path to a directory containing the image to bucketize.",
                "required": True,
            },
        ),
        (
            "--buckets",
            {
                "help": "Resolution buckets, in the format 'width1,height1|width2,height2|...' where width and height must be integers.",
                "required": False,
                "default": "1024,1024|1344,768|928,1152",
            },
        ),
        (
            "--output",
            {
                "help": "Path to the output directory where all the bucketized images will be saved.",
                "required": False,
                "default": "./output/",
            },
        ),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def parse_buckets(raw: str) -> list[tuple[int, int]]:
    """
    Parses a raw string describing resolution buckets into a usable list of resolutions.
    For instance, "1024,1024|1344,768|928,1152" returns [(1024, 1024), (1344, 768), (928, 1152)]

    Buckets are separated by "|" and each bucket must be exactly "width,height".
    Raises ValueError if a bucket does not hold exactly two integers, or if a
    dimension is not strictly positive.
    """
    res = []
    for raw_bucket in raw.split("|"):
        try:
            # Unpacking into two names rejects both missing and extra fields;
            # int() rejects non-numeric text. Both failures are ValueError.
            width_text, height_text = raw_bucket.split(",")
            width, height = int(width_text), int(height_text)
        except ValueError as e:
            raise ValueError(
                f"Cannot parse bucket {raw_bucket}, make sure you use the proper synthax: {e}"
            ) from e
        # The aspect ratio of a bucket is computed as width / height later on,
        # so a zero or negative dimension must be refused here.
        if width <= 0 or height <= 0:
            raise ValueError(
                f"Invalid bucket {raw_bucket}: width and height must be strictly positive."
            )
        res.append((width, height))
    return res
def resize_and_crop(image: Image.Image, width: int, height: int) -> Image.Image:
    """
    Crops an image around its center so that it has the aspect ratio of the target
    resolution, then resizes the cropped image to exactly (width, height).
    Adapted from: https://github.com/comfyanonymous/ComfyUI/blob/8d8dc9a262bfee9f7eb3a2fc381b79a69274f225/comfy/utils.py#L825
    """
    src_width, src_height = image.size
    src_ratio = src_width / src_height
    dst_ratio = width / height

    # Margins removed on each side; only one axis is ever trimmed.
    margin_x, margin_y = 0, 0
    if src_ratio > dst_ratio:
        # Source is too wide: trim the left and right sides equally.
        margin_x = round((src_width - src_width * (dst_ratio / src_ratio)) / 2)
    elif src_ratio < dst_ratio:
        # Source is too tall: trim the top and bottom equally.
        margin_y = round((src_height - src_height * (src_ratio / dst_ratio)) / 2)

    crop_box = (margin_x, margin_y, src_width - margin_x, src_height - margin_y)
    return image.crop(crop_box).resize((width, height), Resampling.LANCZOS)
def bucketize(
    images: list[Image.Image], buckets: list[tuple[int, int]]
) -> list[Image.Image]:
    """
    Resizes and crops every image to the bucket whose aspect ratio (width / height)
    is the closest to the aspect ratio of the image.
    Returns the processed images, in the same order as the input.
    """
    processed = []
    for img in images:
        img_width, img_height = img.size
        img_ratio = img_width / img_height
        # min() keeps the first bucket in case of a tie.
        target = min(
            buckets, key=lambda bucket: abs(img_ratio - bucket[0] / bucket[1])
        )
        processed.append(resize_and_crop(img, *target))
    return processed
def load_images(directory: str, recursive: bool = False) -> list[Image.Image]:
    """
    Load all images from a given directory. Supports recursive search (meaning all subdirectories will be loaded too).

    Only regular files whose extension is .png, .jpg or .jpeg (case-insensitive)
    are loaded. Files are loaded in sorted path order so that the result is
    the same from one run to the next.
    """
    image_extensions = (".png", ".jpg", ".jpeg")

    candidates = []
    if recursive:
        for root, _, files in os.walk(directory):
            for file in files:
                candidates.append(os.path.join(root, file))
    else:
        candidates = [os.path.join(directory, file) for file in os.listdir(directory)]

    images = []
    # os.listdir / os.walk give no ordering guarantee, so sort explicitly.
    for file in sorted(candidates):
        # Compare the lowercased name against dotted extensions so that
        # "photo.PNG" is accepted and a name such as "notapng" is not.
        if not file.lower().endswith(image_extensions):
            continue
        # A non-recursive listing also returns sub-directories; skip them.
        if not os.path.isfile(file):
            continue
        with Image.open(file) as image:
            images.append(
                image.copy()
            )  # Copy image to avoid issues with context manager
    return images
def save_images(images: list[Image.Image], path: str):
    """
    Save images in a given directory

    The directory is created if needed. Images are written as PNG files named
    img000.png, img001.png, ... following their position in the list.
    A failure to save one image is printed and does not stop the others.
    Raises OSError if the output directory cannot be created.
    """
    try:
        # exist_ok avoids the race between checking for the directory and
        # creating it, and still fails if the path exists but is not a directory.
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        raise OSError(f"Failed to create directory {path}: {e}") from e
    for i, image in enumerate(images):
        try:
            filename = f"img{i:03d}.png"
            filepath = os.path.join(path, filename)
            image.save(filepath, format="PNG")
        except Exception as e:
            # Best effort: report the failure and carry on with the next image.
            print(f"Error saving image {i}: {e}")
if __name__ == "__main__":
    # Script entry point: parse the CLI, load, bucketize and save the images.
    cli_args = get_args()
    bucket_list = parse_buckets(cli_args.buckets)
    if not bucket_list:
        raise ValueError("At least one bucket must be specified.")
    source_images = load_images(cli_args.input)
    save_images(bucketize(source_images, bucket_list), cli_args.output)
    print("Done.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment