@chamecall
Created March 25, 2025 13:19
Custom image/bbox Albumentations augmentation: a top-biased random crop.
import numpy as np
import cv2
from albumentations.core.transforms_interface import DualTransform


class TopBiasedRandomCrop(DualTransform):
    """
    Randomly crops an image so that the cropped region's width is at least min_width
    and its height is at least min_height of the original. The vertical (y) offset is
    biased toward the top by sampling from a Beta distribution with parameters
    beta_alpha and beta_beta. Bounding boxes in [x_min, y_min, x_max, y_max] format
    are adjusted accordingly. A usage sketch follows the class definition.

    Args:
        min_width (float): Minimum relative width of the crop (0 < min_width <= 1).
        min_height (float): Minimum relative height of the crop (0 < min_height <= 1).
        beta_alpha (float): Alpha parameter of the Beta distribution (vertical bias).
        beta_beta (float): Beta parameter of the Beta distribution.
        p (float): Probability of applying the transform.
    """

    def __init__(self, min_width=0.7, min_height=0.5, beta_alpha=1.0, beta_beta=2.0, p=1.0):
        # Pass p by keyword: older DualTransform signatures take always_apply first,
        # so a positional argument would be misinterpreted.
        super().__init__(p=p)
        if not (0 < min_width <= 1):
            raise ValueError("min_width must be in the interval (0, 1].")
        if not (0 < min_height <= 1):
            raise ValueError("min_height must be in the interval (0, 1].")
        self.min_width = min_width
        self.min_height = min_height
        self.beta_alpha = beta_alpha
        self.beta_beta = beta_beta

    def get_params_dependent_on_data(self, params, data):
        # Recent Albumentations releases call this hook; delegate to the older-style one below.
        return self.get_params_dependent_on_targets({"image": data["image"]})

    def get_params_dependent_on_targets(self, params) -> dict:
        img = params["image"]
        height, width = img.shape[:2]

        # Determine crop dimensions.
        crop_width = int(np.random.uniform(self.min_width, 1.0) * width)
        crop_height = int(np.random.uniform(self.min_height, 1.0) * height)
        crop_width = min(crop_width, width)
        crop_height = min(crop_height, height)

        # Maximum possible offsets.
        x_max = width - crop_width
        y_max = height - crop_height

        # The horizontal offset is uniform; the vertical offset is drawn from
        # Beta(beta_alpha, beta_beta), which with the defaults (1, 2) concentrates
        # mass near 0, i.e. near the top of the image.
        x1 = np.random.randint(0, x_max + 1) if x_max > 0 else 0
        y_sample = np.random.beta(self.beta_alpha, self.beta_beta)
        y1 = int(y_sample * y_max) if y_max > 0 else 0

        crop_params = [x1, y1, x1 + crop_width, y1 + crop_height]
        # Return crop_params plus the new shape info, so that bbox filtering uses the cropped dimensions.
        return {"crop_params": crop_params, "rows": crop_height, "cols": crop_width}

    def apply(self, img, **params):
        crop_params = params.get("crop_params")
        if crop_params is None:
            return img
        x1, y1, x2, y2 = crop_params
        return img[y1:y2, x1:x2]

    def apply_to_bbox(self, bbox, **params):
        crop_params = params.get("crop_params")
        if crop_params is None:
            return bbox
        x1, y1, x2, y2 = crop_params
        # Shift the box into the crop's coordinate frame and clip it to the crop boundaries.
        new_bbox = [
            np.clip(bbox[0] - x1, 0, x2 - x1),
            np.clip(bbox[1] - y1, 0, y2 - y1),
            np.clip(bbox[2] - x1, 0, x2 - x1),
            np.clip(bbox[3] - y1, 0, y2 - y1),
        ]
        # Preserve any extra fields (e.g. class labels) appended to the box.
        if len(bbox) > 4:
            new_bbox.extend(bbox[4:])
        return new_bbox

    def apply_to_bboxes(self, bboxes, **params):
        transformed = [self.apply_to_bbox(bbox, **params) for bbox in bboxes]
        # Convert to a NumPy array so that further processing (e.g., filtering) works.
        return np.array(transformed, dtype=np.float32)

    def get_transform_init_args_names(self):
        return ("min_width", "min_height", "beta_alpha", "beta_beta")