# GitHub gist mikkohei13/5d3538f5b8f0797302dd7334646e87ea: "Image augmentation for machine learning"
# (last active March 8, 2025 19:40; saved from the gist page).
# GitHub warning: this file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it, open the
# file in an editor that reveals hidden Unicode characters (see GitHub's notes on
# bidirectional Unicode characters).
# Image augmentation for machine learning
# Script that walks through the images in a directory and its subdirectories, replaces
# the background of selected images using rembg, and applies random transformations to
# a chosen number of images in each directory.
import gc
import random
import sys
import time
import traceback
from pathlib import Path

import numpy as np
import rembg
from PIL import Image, ImageEnhance, ImageFilter
# Index into BGR_IMAGE_PATHS of the background to use next; advanced after each
# successful load so backgrounds are used in round-robin order.
current_bg_index = 0


def get_next_background():
    """Load the next background image in round-robin order.

    Reads ``BGR_IMAGE_PATHS[current_bg_index]``, then advances
    ``current_bg_index``, wrapping back to 0 after the last path.

    Returns:
        A fully decoded ``PIL.Image.Image`` whose source file is already closed.

    Exits the program (``SystemExit("Stopping")``) when no backgrounds are
    configured or the image cannot be opened or decoded, because no
    augmentation can be produced without a background.
    """
    global current_bg_index

    # Guard explicitly: indexing an empty list would otherwise fail before
    # bg_path is bound, turning the intended clean exit into an
    # UnboundLocalError inside the error handler.
    if not BGR_IMAGE_PATHS:
        print("No background images configured (BGR_IMAGE_PATHS is empty)")
        sys.exit("Stopping")

    # Modulo keeps the index valid even if the list was shortened at runtime.
    bg_path = BGR_IMAGE_PATHS[current_bg_index % len(BGR_IMAGE_PATHS)]
    try:
        # Image.open is lazy and keeps the file open. copy() forces a full
        # decode, so decode errors surface here and the with-block can close
        # the handle before we return.
        with Image.open(bg_path) as background:
            loaded = background.copy()
    except Exception as e:
        print(f"Failed to load background image: {bg_path}")
        print(f"Error: {e}")
        sys.exit("Stopping")

    print(f"Using background: {bg_path}")
    # Advance for next time, wrapping around at the end of the list.
    current_bg_index = (current_bg_index + 1) % len(BGR_IMAGE_PATHS)
    return loaded
def _downscale(image, max_dimension=1000):
    """Return *image* shrunk so neither side exceeds *max_dimension* pixels.

    The aspect ratio is preserved. Images already within the limit are returned
    unchanged. Large inputs are shrunk to reduce memory use during background
    removal.
    """
    width, height = image.size
    if width <= max_dimension and height <= max_dimension:
        return image
    # Pin the longer side to exactly max_dimension; scale the other side and
    # never let it collapse to 0 px for extremely thin images.
    if width > height:
        new_width = max_dimension
        new_height = max(1, int(height * (max_dimension / width)))
    else:
        new_height = max_dimension
        new_width = max(1, int(width * (max_dimension / height)))
    print(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
    return image.resize((new_width, new_height), Image.LANCZOS)


def _replace_background(foreground):
    """Cut the subject out of *foreground* and place it on the next background.

    Uses the module-level rembg session ``model`` and advances the background
    rotation via ``get_next_background()``.

    Returns:
        An opaque RGB image with the same size as the rembg cutout.
    """
    # Remove background with alpha matting - adjusted parameters to avoid performance warning
    cutout = rembg.remove(
        foreground,
        session=model,
        alpha_matting=True,
        alpha_matting_foreground_threshold=230,  # Slightly reduced from 240
        alpha_matting_background_threshold=20,  # Slightly increased from 10
        alpha_matting_erode_size=5,  # Reduced from 10
        post_process_mask=True,  # Add post-processing to improve results
    )
    background = get_next_background()
    # The background is opaque, so composite straight into RGB. Pasting with the
    # cutout as its own mask blends subject over background (assumes rembg
    # returns an RGBA image for PIL input). No alpha channel is carried into
    # the later resize/noise/blur steps, where it would only be distorted by
    # premultiplied resampling and then discarded.
    composite = background.resize(cutout.size).convert("RGB")
    composite.paste(cutout, (0, 0), cutout)
    return composite


def _augment(image):
    """Apply the random flip / brightness / zoom / colour / noise / blur augmentations.

    *image* is an RGB image. The returned image has the same size and mode.
    """
    # Always flip: horizontally or vertically with equal probability.
    if random.random() < 0.5:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    else:
        image = image.transpose(Image.FLIP_TOP_BOTTOM)

    # Rotation is intentionally not applied (flipping is usually enough). If it
    # is needed, image.transpose(Image.ROTATE_90 / ROTATE_180 / ROTATE_270)
    # avoids the black bars that Image.rotate() leaves on non-square images.

    # Increase brightness by 0-20 %
    brightness_factor = random.uniform(1, 1.2)
    image = ImageEnhance.Brightness(image).enhance(brightness_factor)

    # Zoom in by 0-20 %: take a centred window 1/zoom the size of the image and
    # scale it back to the original size. Merely enlarging the whole image
    # would leave the framing unchanged once a model resizes its input.
    zoom_factor = random.uniform(1, 1.2)
    width, height = image.size
    crop_width = max(1, round(width / zoom_factor))
    crop_height = max(1, round(height / zoom_factor))
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    image = image.crop((left, top, left + crop_width, top + crop_height)).resize((width, height))

    # Conservative colour jittering.
    # Contrast: 0.9-1.1 (10% variation)
    contrast_factor = random.uniform(0.9, 1.1)
    image = ImageEnhance.Contrast(image).enhance(contrast_factor)
    # Saturation: 0.9-1.1 (10% variation)
    saturation_factor = random.uniform(0.9, 1.1)
    image = ImageEnhance.Color(image).enhance(saturation_factor)

    # Subtle Gaussian noise: std-dev 1-5 on the 0-255 scale (roughly 0.4-2 %).
    if random.random() < 0.5:
        pixels = np.array(image)
        noise_level = random.uniform(1, 5)
        noise = np.random.normal(0, noise_level, pixels.shape)
        image = Image.fromarray(np.clip(pixels + noise, 0, 255).astype(np.uint8))

    # Subtle blur
    if random.random() < 0.5:
        blur_radius = random.uniform(0.2, 0.5)  # Small radius for a subtle effect
        image = image.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    return image


def _augmented_path(image_path, suffix):
    """Return the output path ``<stem>_<suffix><ext>`` next to *image_path*."""
    image_path = Path(image_path)
    return image_path.with_name(f"{image_path.stem}_{suffix}{image_path.suffix}")


def process_image(image_path):
    """Create one augmented copy of *image_path* with a replaced background.

    The subject is cut out with rembg and composited onto the next background
    from ``BGR_IMAGE_PATHS``. The result is randomly augmented and saved next to
    the source as ``<stem>_<SUFFIX><ext>``.

    Args:
        image_path: ``Path`` or path string of the source image.

    Returns:
        True if an augmented image was written. False if the image was skipped
        (its name already contains SUFFIX) or processing failed; failures are
        printed rather than raised so a batch can continue.
    """
    image_path = Path(image_path)
    print(f"Processing image: {image_path}")
    # Never augment an augmentation.
    if SUFFIX in image_path.stem:
        print(f"Skipping already processed image: {image_path}")
        return False

    try:
        # Image.open is lazy; finish reading inside the with-block so the file
        # handle is closed. Converting to RGB normalises palette, greyscale,
        # CMYK and RGBA sources (presumably what rembg handles best - verify).
        with Image.open(image_path) as source:
            foreground = _downscale(source.convert("RGB"))

        output = _replace_background(foreground)
        del foreground  # Free the input before allocating augmentation buffers.
        output = _augment(output)

        output_path = _augmented_path(image_path, SUFFIX)
        output.save(str(output_path))
        print(f"Processed: {image_path} -> {output_path}")
        return True
    except Exception as e:
        # Best effort per image: report and let the caller move on.
        print(f"Error processing {image_path}: {e}")
        return False
    finally:
        # Release large intermediate buffers before the next image.
        gc.collect()
# File extensions (compared case-insensitively) that are treated as images.
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png"})


def _list_images(directory):
    """Return the image files directly inside *directory*, sorted.

    Matching on the lower-cased suffix finds every casing (``.jpg``, ``.JPG``,
    ``.Jpg``) exactly once. Separate ``*.jpg`` / ``*.JPG`` globs would return
    the same file twice where glob matching is case-insensitive (e.g. Windows).
    """
    return sorted(
        p for p in Path(directory).iterdir()
        if p.is_file() and p.suffix.lower() in _IMAGE_EXTENSIONS
    )


def _augment_directory_images(directory, image_files):
    """Augment enough not-yet-augmented images from *image_files* to reach IMAGES_TO_AUGMENT."""
    # Files whose stem contains SUFFIX are outputs of earlier runs.
    processed_images = [img for img in image_files if SUFFIX in img.stem]
    processed_count = len(processed_images)
    print(f"Found {processed_count} already processed images in {directory}")

    # Stems of sources that already have an augmentation, so they are not
    # picked again; e.g. "image1_augmentedbgr.jpg" came from a source with stem "image1".
    already_augmented_sources = {img.stem.replace(f"_{SUFFIX}", "") for img in processed_images}
    print(f"Found {len(already_augmented_sources)} source images that have already been augmented")

    images_to_process = IMAGES_TO_AUGMENT - processed_count
    if images_to_process <= 0:
        print(f"Already have {processed_count} processed images in {directory}, which meets or exceeds the target of {IMAGES_TO_AUGMENT}")
        return
    print(f"Need to process {images_to_process} more images to reach the target of {IMAGES_TO_AUGMENT}")

    original_images = [
        img for img in image_files
        if SUFFIX not in img.stem and img.stem not in already_augmented_sources
    ]
    if not original_images:
        print(f"No unprocessed source images found in {directory}")
        return
    print(f"Found {len(original_images)} unprocessed source images")

    # Randomly pick only as many as are still needed.
    if len(original_images) > images_to_process:
        selected_images = random.sample(original_images, images_to_process)
    else:
        selected_images = original_images

    for image_path in selected_images:
        process_image(image_path)
        gc.collect()  # Keep peak memory down between images.


def process_directory(directory):
    """Augment up to IMAGES_TO_AUGMENT images in *directory*, then recurse.

    Existing augmented images (stem contains SUFFIX) count towards the target,
    and their sources are not augmented again. Every subdirectory is then
    processed the same way, in sorted order.

    Args:
        directory: ``Path`` or path string of the directory to process.
    """
    directory = Path(directory)
    print(f"Processing directory: {directory}")

    image_files = _list_images(directory)
    if not image_files:
        print(f"No image files found in {directory}")
    else:
        _augment_directory_images(directory, image_files)

    # Process all subdirectories recursively
    for subdir in sorted(directory.iterdir()):
        if subdir.is_dir():
            process_directory(subdir)
# ---- Settings ----
# Target number of augmented images per directory; augmented images that
# already exist count towards it.
IMAGES_TO_AUGMENT = 3
# Marker added to output filenames as "<stem>_<SUFFIX><ext>". Any file whose
# stem contains it is treated as an existing augmentation and never re-augmented.
SUFFIX = "augmentedbgr"
# Root directory to process; its subdirectories are processed recursively.
INPUT_DIR = "input"
# Background images, used in round-robin order. The paths are relative to the
# current working directory.
BGR_IMAGE_PATHS = [
    f"./bgrs/{bg_name}.jpg"
    for bg_name in (
        "cardboard_blur",
        "fabric_black_blur",
        "fabric_blur",
        "grey_blur",
        "mm_blur",
        "paper_blur",
        "plastic_blur",
        "tatar_blur",
        "veronica_blur",
        "white",
        "grey",
    )
]
# Validate the configuration before loading the rembg model, so a wrong path
# fails immediately instead of midway through a run.
input_path = Path(INPUT_DIR)
if not input_path.exists():
    print(f"Input directory '{INPUT_DIR}' does not exist.")
    sys.exit("Stopping")

if not BGR_IMAGE_PATHS:
    print("No background images configured (BGR_IMAGE_PATHS is empty)")
    sys.exit("Stopping")
missing_backgrounds = [p for p in BGR_IMAGE_PATHS if not Path(p).is_file()]
if missing_backgrounds:
    print(f"Background image(s) not found: {', '.join(missing_backgrounds)}")
    sys.exit("Stopping")

# Initialize the rembg model; process_image() passes this module-level
# session to rembg.remove().
model = rembg.new_session(model_name="isnet-general-use")

# Start recursive processing from the input directory
try:
    process_directory(input_path)
except Exception as e:
    print(f"Error during processing: {e}")
    traceback.print_exc()
    # Exit non-zero so shells/CI can tell the run failed (previously the
    # error was swallowed and the script exited with status 0).
    sys.exit(1)
else:
    print("Done")
# (GitHub gist page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.")