Image augmentation for machine learning
# Image augmentation for machine learning
# Script that recursively loops through images in subdirectories, replaces their backgrounds
# using rembg, and applies random transformations to a selected number of images in each directory.
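#
# Dependencies (inferred from the imports below): rembg, Pillow and numpy.
# Assuming a standard pip environment, something like the following should install them,
# after which the script can be run directly with Python:
#
#   pip install rembg pillow numpy
#   python augment_backgrounds.py   # hypothetical filename - use whatever name this file is saved under
#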
import rembg
from pathlib import Path
import random
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
import gc
import time
import sys
# Global variable to track the current background index
current_bg_index = 0

def get_next_background():
    """Get the next background image from the list and increment the index"""
    global current_bg_index
    # Pick the current background path (the index is always kept in range by the modulo below)
    bg_path = BGR_IMAGE_PATHS[current_bg_index]
    try:
        # Load the background image
        background = Image.open(bg_path)
        print(f"Using background: {bg_path}")
        # Increment the index for next time, wrapping around if needed
        current_bg_index = (current_bg_index + 1) % len(BGR_IMAGE_PATHS)
        return background
    except Exception as e:
        print(f"Failed to load background image: {bg_path}")
        print(f"Error: {e}")
        sys.exit("Stopping")

def process_image(image_path):
    """Remove background from image and replace it with the specified background"""
    print(f"Processing image: {image_path}")
    # Skip if the image filename already contains the SUFFIX
    if SUFFIX in image_path.stem:
        print(f"Skipping already processed image: {image_path}")
        return False
    try:
        # Read the image with PIL
        input_image = Image.open(str(image_path))
        if input_image is None:
            print(f"Failed to read image: {image_path}")
            return False
        # Resize large images to reduce memory usage
        max_dimension = 1000  # Set a reasonable maximum dimension
        if input_image.width > max_dimension or input_image.height > max_dimension:
            # Calculate new dimensions while preserving aspect ratio
            if input_image.width > input_image.height:
                new_width = max_dimension
                new_height = int(input_image.height * (max_dimension / input_image.width))
            else:
                new_height = max_dimension
                new_width = int(input_image.width * (max_dimension / input_image.height))
            print(f"Resizing image from {input_image.width}x{input_image.height} to {new_width}x{new_height}")
            input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
        # Remove background with alpha matting - parameters adjusted to avoid a performance warning
        output_transparent = rembg.remove(
            input_image,
            session=model,
            alpha_matting=True,
            alpha_matting_foreground_threshold=230,  # Slightly reduced from 240
            alpha_matting_background_threshold=20,   # Slightly increased from 10
            alpha_matting_erode_size=5,              # Reduced from 10
            post_process_mask=True                   # Post-process the mask to improve results
        )
        # Explicitly delete the input image to free memory
        del input_image
        # Get the next background image
        background_image = get_next_background()
        # Resize the background to match the output image size
        bg_resized = background_image.resize(output_transparent.size)
        # Delete the original background image to free memory
        del background_image
        # Composite the transparent output onto the background
        output = Image.new("RGBA", output_transparent.size)
        output.paste(bg_resized, (0, 0))
        output.paste(output_transparent, (0, 0), output_transparent)
        # Delete intermediate images to free memory
        del bg_resized
        del output_transparent
        # Flip the image horizontally or vertically
        if random.random() < 0.5:
            output = output.transpose(Image.FLIP_LEFT_RIGHT)
        else:
            output = output.transpose(Image.FLIP_TOP_BOTTOM)
        # Rotate the image 90, 180 or 270 degrees (disabled by default).
        # Note that this can create black bars on the sides of the image; use only if flipping is not enough.
        # output = output.rotate(random.choice([90, 180, 270]))
        # Increase brightness by 0-20 %
        brightness_factor = random.uniform(1, 1.2)
        output = ImageEnhance.Brightness(output).enhance(brightness_factor)
        # Scale the image up by 0-20 % (this enlarges the whole image rather than cropping a zoomed view)
        zoom_factor = random.uniform(1, 1.2)
        output = output.resize((int(output.width * zoom_factor), int(output.height * zoom_factor)))
        # Add conservative color jittering
        # Contrast: 0.9-1.1 (10 % variation)
        contrast_factor = random.uniform(0.9, 1.1)
        output = ImageEnhance.Contrast(output).enhance(contrast_factor)
        # Saturation: 0.9-1.1 (10 % variation)
        saturation_factor = random.uniform(0.9, 1.1)
        output = ImageEnhance.Color(output).enhance(saturation_factor)
        # Add subtle Gaussian noise (sigma 1-5 on the 0-255 scale, i.e. roughly 0.4-2 % of the pixel range)
        if random.random() < 0.5:
            output_array = np.array(output)
            noise_level = random.uniform(1, 5)  # Very conservative noise level
            noise = np.random.normal(0, noise_level, output_array.shape)
            noisy_array = np.clip(output_array + noise, 0, 255).astype(np.uint8)
            output = Image.fromarray(noisy_array)
        # Add subtle blur
        if random.random() < 0.5:
            blur_radius = random.uniform(0.2, 0.5)  # Small blur radius for a subtle effect
            output = output.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        # Generate the output filename
        stem = Path(image_path).stem
        suffix = Path(image_path).suffix
        output_path = Path(image_path).parent / f"{stem}_{SUFFIX}{suffix}"
        # Save the processed image
        output = output.convert("RGB")  # Convert to RGB for saving as JPG
        output.save(str(output_path))
        # Delete the output image to free memory
        del output
        # Force garbage collection
        gc.collect()
        print(f"Processed: {image_path} -> {output_path}")
        # Add a small delay to allow memory to be freed
        time.sleep(0.1)
        return True
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        # Force garbage collection
        gc.collect()
        return False

def process_directory(directory):
    print(f"Processing directory: {directory}")
    # Get all image files in the directory
    image_files = []
    for ext in ['.jpg', '.jpeg', '.png']:
        image_files.extend(list(directory.glob(f"*{ext.lower()}")))
        image_files.extend(list(directory.glob(f"*{ext.upper()}")))
    if not image_files:
        print(f"No image files found in {directory}")
    else:
        # Find all processed images (with SUFFIX)
        processed_images = [img for img in image_files if SUFFIX in img.stem]
        processed_count = len(processed_images)
        print(f"Found {processed_count} already processed images in {directory}")
        # Extract the base names of source images that have already been augmented
        # For example, if we have "image1_augmentedbgr.jpg", the source is "image1.jpg"
        already_augmented_sources = set()
        for processed_img in processed_images:
            # Remove the SUFFIX from the stem to get the original source name
            original_stem = processed_img.stem.replace(f"_{SUFFIX}", "")
            already_augmented_sources.add(original_stem)
        print(f"Found {len(already_augmented_sources)} source images that have already been augmented")
        # Calculate how many more images to process
        images_to_process = max(0, IMAGES_TO_AUGMENT - processed_count)
        if images_to_process <= 0:
            print(f"Already have {processed_count} processed images in {directory}, which meets or exceeds the target of {IMAGES_TO_AUGMENT}")
        else:
            print(f"Need to process {images_to_process} more images to reach the target of {IMAGES_TO_AUGMENT}")
            # Filter out images that already contain the SUFFIX or have already been augmented
            original_images = []
            for img in image_files:
                if SUFFIX not in img.stem and img.stem not in already_augmented_sources:
                    original_images.append(img)
            if not original_images:
                print(f"No unprocessed source images found in {directory}")
            else:
                print(f"Found {len(original_images)} unprocessed source images")
                # Randomly select images_to_process images if there are more than needed
                if len(original_images) > images_to_process:
                    selected_images = random.sample(original_images, images_to_process)
                else:
                    selected_images = original_images
                # Process each selected image
                for image_path in selected_images:
                    process_image(image_path)
                    # Force garbage collection after each image
                    gc.collect()
    # Process all subdirectories recursively
    for subdir in directory.iterdir():
        if subdir.is_dir():
            process_directory(subdir)

# Settings:
# Number of images to process in each directory
IMAGES_TO_AUGMENT = 3
# Suffix for the output images
SUFFIX = "augmentedbgr"
# Input directory which can have subdirectories
INPUT_DIR = "input"
# Background image paths
BGR_IMAGE_PATHS = [
    "./bgrs/cardboard_blur.jpg",
    "./bgrs/fabric_black_blur.jpg",
    "./bgrs/fabric_blur.jpg",
    "./bgrs/grey_blur.jpg",
    "./bgrs/mm_blur.jpg",
    "./bgrs/paper_blur.jpg",
    "./bgrs/plastic_blur.jpg",
    "./bgrs/tatar_blur.jpg",
    "./bgrs/veronica_blur.jpg",
    "./bgrs/white.jpg",
    "./bgrs/grey.jpg"
]
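
# Expected layout (a sketch based on the settings above; subdirectory names are just examples):
#   ./input/<subdir>/*.jpg|*.jpeg|*.png   source images, grouped into subdirectories
#   ./bgrs/*.jpg                          background images listed in BGR_IMAGE_PATHS
# Augmented copies are written next to each source image as <name>_augmentedbgr.<ext>.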
# Initialize the rembg model
model = rembg.new_session(model_name="isnet-general-use")
# Check that the input directory exists
input_path = Path(INPUT_DIR)
if not input_path.exists():
    print(f"Input directory '{INPUT_DIR}' does not exist.")
    sys.exit("Stopping")
# Start recursive processing from the input directory
try:
    process_directory(input_path)
    print("Done")
except Exception as e:
    print(f"Error during processing: {e}")
    # Force garbage collection before exiting
    gc.collect()