Skip to content

Instantly share code, notes, and snippets.

@brayevalerien
Created January 31, 2025 10:46
Show Gist options
  • Save brayevalerien/4bfc791adc35cc8ac845c5dc2e785042 to your computer and use it in GitHub Desktop.
Save brayevalerien/4bfc791adc35cc8ac845c5dc2e785042 to your computer and use it in GitHub Desktop.
Automatically sorts and resizes/crops images into resolution buckets for diffusion model training.
import argparse
import os
import sys

from PIL import Image
from PIL.Image import Resampling
def get_args(argv=None):
    """
    Parse the command-line arguments of the bucketize script.

    Args:
        argv: Optional list of argument strings to parse instead of ``sys.argv[1:]``
            (argparse's default when ``None``). Useful for calling the script
            programmatically or from tests.

    Returns:
        An ``argparse.Namespace`` with ``input``, ``buckets`` and ``output`` attributes.
    """
    parser = argparse.ArgumentParser(
        prog="Bucketize",
        description="Automatically sorts and resizes/crops images into resolution buckets for diffusion model training.",
    )
    parser.add_argument(
        "--input",
        help="Path to a directory containing the images to bucketize.",
        required=True,
    )
    parser.add_argument(
        "--buckets",
        help="Resolution buckets, in the format 'width1,height1|width2,height2|...' where width and height must be integers.",
        required=False,
        default="1024,1024|1344,768|928,1152",
    )
    parser.add_argument(
        "--output",
        help="Path to the output directory where all the bucketized images will be saved.",
        required=False,
        default="./output/",
    )
    return parser.parse_args(argv)
def parse_buckets(raw: str) -> list[tuple[int, int]]:
    """
    Parse a raw string describing resolution buckets into a list of (width, height) tuples.

    For instance, "1024,1024|1344,768|928,1152" returns [(1024, 1024), (1344, 768), (928, 1152)].
    Surrounding whitespace around numbers is tolerated. Empty segments (e.g. from a
    trailing "|" or an empty string) are skipped, so "" returns [].

    Raises:
        ValueError: if a bucket is not exactly two comma-separated integers, or if
            either dimension is not strictly positive (a zero height would later
            cause a division by zero when computing aspect ratios).
    """
    res = []
    for raw_bucket in raw.split("|"):
        if not raw_bucket.strip():
            continue
        try:
            # Tuple unpacking raises ValueError unless there are exactly two parts,
            # and int() raises ValueError on non-numeric text.
            raw_width, raw_height = raw_bucket.split(",")
            width, height = int(raw_width), int(raw_height)
        except ValueError as e:
            raise ValueError(
                f"Cannot parse bucket {raw_bucket}, make sure you use the proper syntax: {e}"
            ) from e
        if width <= 0 or height <= 0:
            raise ValueError(
                f"Bucket {raw_bucket} must have a strictly positive width and height."
            )
        res.append((width, height))
    return res
def resize_and_crop(image: Image.Image, width: int, height: int) -> Image.Image:
    """
    Center-crop an image to the target aspect ratio, then resize it to exactly (width, height).

    The crop removes equal margins from both sides of the dimension that is too
    long relative to the target aspect ratio; the other dimension is kept whole.
    The cropped region is then resized with Lanczos resampling.
    Adapted from: https://github.com/comfyanonymous/ComfyUI/blob/8d8dc9a262bfee9f7eb3a2fc381b79a69274f225/comfy/utils.py#L825

    Args:
        image: Source image.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A new image of size (width, height).
    """
    old_width, old_height = image.size
    old_aspect = old_width / old_height
    new_aspect = width / height
    x = 0
    y = 0
    if old_aspect > new_aspect:
        # Too wide: trim left/right so that the kept width is old_height * new_aspect.
        x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
    elif old_aspect < new_aspect:
        # Too tall: trim top/bottom so that the kept height is old_width / new_aspect.
        y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
    # For extreme aspect-ratio mismatches the rounded margin can reach half the
    # dimension, which would yield an empty crop box; always keep at least 1 px.
    x = min(x, (old_width - 1) // 2)
    y = min(y, (old_height - 1) // 2)
    cropped = image.crop((x, y, old_width - x, old_height - y))
    return cropped.resize((width, height), Resampling.LANCZOS)
def bucketize(
    images: list[Image.Image], buckets: list[tuple[int, int]]
) -> list[Image.Image]:
    """
    Fit each image into the resolution bucket whose aspect ratio is closest to its own.

    The closest bucket is the one minimising
    |image_width / image_height - bucket_width / bucket_height|; on a tie the
    earliest bucket in ``buckets`` wins. The image is then cropped and resized
    to that bucket with ``resize_and_crop``.

    Returns:
        A new list of images, in the same order as ``images``.
    """
    fitted = []
    for img in images:
        img_width, img_height = img.size
        aspect = img_width / img_height
        closest = min(
            buckets, key=lambda bucket: abs(aspect - bucket[0] / bucket[1])
        )
        fitted.append(resize_and_crop(img, *closest))
    return fitted
def load_images(
    directory: str,
    recursive: bool = False,
    extensions: tuple[str, ...] = (".png", ".jpg", ".jpeg"),
) -> list[Image.Image]:
    """
    Load all images from a given directory, in sorted path order.

    Args:
        directory: Directory to scan.
        recursive: If True, subdirectories are scanned too (via ``os.walk``).
        extensions: File suffixes (including the dot) to load; matched
            case-insensitively, so "IMG.JPG" is picked up as well.

    Returns:
        The loaded images, fully read into memory so they remain usable after
        their files are closed.
    """
    if recursive:
        candidates = [
            os.path.join(root, name)
            for root, _, names in os.walk(directory)
            for name in names
        ]
    else:
        candidates = [os.path.join(directory, name) for name in os.listdir(directory)]
    wanted = tuple(ext.lower() for ext in extensions)
    # Sorting makes the output numbering deterministic; listdir/walk order is arbitrary.
    # isfile() skips directories whose names happen to end with an image suffix.
    image_files = sorted(
        path
        for path in candidates
        if path.lower().endswith(wanted) and os.path.isfile(path)
    )
    images = []
    for path in image_files:
        with Image.open(path) as image:
            # Copy image to avoid issues with context manager (the file is closed on exit).
            images.append(image.copy())
    return images
def save_images(images: list[Image.Image], path: str):
    """
    Save images as sequentially numbered PNG files (img000.png, img001.png, ...) in a directory.

    The directory (and any missing parents) is created if needed. An image that
    fails to save is reported on stderr and skipped, so one bad image does not
    abort the whole batch.

    Args:
        images: Images to save, numbered in list order.
        path: Output directory.

    Raises:
        OSError: if the output directory cannot be created (for example when
            ``path`` already exists as a regular file).
    """
    try:
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        raise OSError(f"Failed to create directory {path}: {e}") from e
    # At least 3 digits (the historical naming), widened for large batches so the
    # filenames still sort lexicographically in numeric order.
    digits = max(3, len(str(len(images) - 1)))
    for i, image in enumerate(images):
        filepath = os.path.join(path, f"img{i:0{digits}d}.png")
        try:
            image.save(filepath, format="PNG")
        except Exception as e:  # best-effort: report and continue with the rest
            print(f"Error saving image {i}: {e}", file=sys.stderr)
if __name__ == "__main__":
    # Script entry point: parse CLI args, bucketize every image found in the
    # input directory and write the results to the output directory.
    args = get_args()
    input_path = args.input
    output_path = args.output
    # Fail early with a clear message instead of a raw os.listdir traceback.
    if not os.path.isdir(input_path):
        raise NotADirectoryError(f"Input path {input_path} is not a directory.")
    buckets = parse_buckets(args.buckets)
    if not buckets:
        raise ValueError("At least one bucket must be specified.")
    images = load_images(input_path)
    if not images:
        print(f"No PNG/JPEG images found in {input_path}.")
    processed = bucketize(images, buckets)
    save_images(processed, output_path)
    print("Done.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment