Skip to content

Instantly share code, notes, and snippets.

@innat
Created June 12, 2022 12:02
Show Gist options
  • Save innat/0524ee77de17f0601f0dee69aa52c713 to your computer and use it in GitHub Desktop.
Save innat/0524ee77de17f0601f0dee69aa52c713 to your computer and use it in GitHub Desktop.
Vectorized Implementation of CutMix Augmentation.
import tensorflow as tf
from tensorflow.keras import layers
class CutMix(layers.Layer):
    """Vectorized CutMix augmentation for a batch of images and labels.

    Original implementation: https://github.com/keras-team/keras-cv.
    The original implementation provides a richer interface for applying the
    augmentation to various CV tasks (e.g. object detection) and performs
    many additional validation checks.
    Derived and modified for simpler usage: M.Innat.

    For every sample in the batch a rectangular patch is replaced by the
    same region of another, randomly chosen sample of the batch, and the
    labels of both samples are blended in proportion to the patch area.

    Args:
        alpha: Concentration of the symmetric Beta(alpha, alpha) distribution
            from which the mixing ratio is sampled.
        seed: Optional seed passed to the batch shuffle that selects the
            mixing partner of each sample. The other random draws of this
            layer are not seeded.
    """

    def __init__(self, alpha=1.0, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.seed = seed

    @staticmethod
    def _sample_from_beta(alpha, beta, shape):
        """Draw Beta(alpha, beta) samples of the given shape.

        Uses the identity X / (X + Y) ~ Beta(a, b) for independent
        X ~ Gamma(a, 1) and Y ~ Gamma(b, 1).
        """
        # The second positional argument of tf.random.gamma is the *shape*
        # (concentration) parameter. Passing the concentration through the
        # `beta=` (rate) keyword with a fixed shape of 1.0 would make the
        # ratio Uniform(0, 1) whenever alpha == beta, silently ignoring it.
        sample_alpha = tf.random.gamma(shape, alpha)
        sample_beta = tf.random.gamma(shape, beta)
        return sample_alpha / (sample_alpha + sample_beta)

    def _cutmix_labels(self, labels, lambda_sample, permutation_order):
        """Blend each label with the label of its mixing partner.

        `lambda_sample` is reshaped to a column so that it broadcasts over
        the last axis of `labels`.
        """
        partner_labels = tf.gather(labels, permutation_order)
        lambda_sample = tf.reshape(lambda_sample, [-1, 1])
        return lambda_sample * labels + (1.0 - lambda_sample) * partner_labels

    def _cutmix_samples(self, images):
        """Paste a random patch from a shuffled copy of the batch.

        Returns:
            A tuple `(images, lambda_sample, permutation_order)`: the mixed
            images, the per-sample fraction attributed to the original image
            (float32), and the permutation used to pick mixing partners.
        """
        input_shape = tf.shape(images)
        batch_size = input_shape[0]
        image_height = input_shape[1]
        image_width = input_shape[2]

        permutation_order = tf.random.shuffle(
            tf.range(0, batch_size), seed=self.seed
        )
        lambda_sample = self._sample_from_beta(
            self.alpha, self.alpha, (batch_size,)
        )

        # A patch covering (1 - lambda) of the area has sides scaled by
        # sqrt(1 - lambda).
        ratio = tf.math.sqrt(1 - lambda_sample)
        cut_height = tf.cast(
            ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32
        )
        # The width must scale with the image *width*; using the height here
        # breaks non-square images and can push lambda below zero.
        cut_width = tf.cast(
            ratio * tf.cast(image_width, dtype=tf.float32), dtype=tf.int32
        )

        random_center_height = tf.random.uniform(
            shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32
        )
        random_center_width = tf.random.uniform(
            shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32
        )

        # Recompute lambda from the integer patch size actually requested.
        # NOTE: a patch whose center lies near a border extends past the
        # image and is effectively clipped by the mask; that clipping is not
        # reflected in this lambda.
        bounding_box_area = cut_height * cut_width
        lambda_sample = 1.0 - bounding_box_area / (image_height * image_width)
        lambda_sample = tf.cast(lambda_sample, dtype=tf.float32)

        images = self.fill_rectangle(
            images,
            random_center_width,
            random_center_height,
            cut_width,
            cut_height,
            tf.gather(images, permutation_order),
        )
        return images, lambda_sample, permutation_order

    def call(self, batch_inputs, training=None):
        """Apply CutMix to `[images, labels]` and return `[images, labels]`.

        Both inputs are cast to float32. `training` is accepted for Keras API
        compatibility but is not consulted: the augmentation is always
        applied.
        """
        bs_images = tf.cast(batch_inputs[0], dtype=tf.float32)  # all image samples
        bs_labels = tf.cast(batch_inputs[1], dtype=tf.float32)  # all label samples
        cutmix_images, lambda_sample, permutation_order = self._cutmix_samples(
            bs_images
        )
        cutmix_labels = self._cutmix_labels(
            bs_labels, lambda_sample, permutation_order
        )
        return [cutmix_images, cutmix_labels]

    def get_config(self):
        """Return the layer config including `alpha` and `seed`."""
        config = super().get_config()
        config.update({"alpha": self.alpha, "seed": self.seed})
        return config

    def fill_rectangle(
        self, images, centers_x, centers_y, widths, heights, fill_values
    ):
        """Replace one rectangle per image with the matching `fill_values`.

        Rectangles are given per sample as center coordinates plus width and
        height; pixels inside the rectangle are taken from `fill_values`,
        all others from `images`.
        """
        images_shape = tf.shape(images)
        images_height = images_shape[1]
        images_width = images_shape[2]

        xywh = tf.stack([centers_x, centers_y, widths, heights], axis=1)
        xywh = tf.cast(xywh, tf.float32)
        corners = self.convert_format(xywh)

        # corners_to_mask expects the shape as (width, height).
        mask_shape = (images_width, images_height)
        is_rectangle = self.corners_to_mask(corners, mask_shape)
        # Trailing axis so the mask broadcasts over the last image axis.
        is_rectangle = tf.expand_dims(is_rectangle, -1)
        return tf.where(is_rectangle, fill_values, images)

    def convert_format(self, boxes):
        """Convert boxes from (center_x, center_y, w, h) to (x0, y0, x1, y1).

        Any columns after the first four are passed through unchanged.
        """
        boxes = tf.cast(boxes, dtype=tf.float32)
        x, y, width, height, rest = tf.split(boxes, [1, 1, 1, 1, -1], axis=-1)
        half_width = width / 2.0
        half_height = height / 2.0
        return tf.concat(
            [
                x - half_width,
                y - half_height,
                x + half_width,
                y + half_height,
                rest,
            ],
            axis=-1,
        )

    def _axis_mask(self, starts, ends, mask_len):
        """Return a boolean mask marking indices i with start <= i < end.

        `starts` and `ends` hold one bound per row; each is compared against
        the index range `[0, mask_len)` along the last axis.
        """
        # A single row of indices; broadcasting against the per-sample
        # bounds yields one mask row per sample without an explicit tile.
        axis_indices = tf.range(mask_len, dtype=starts.dtype)
        axis_indices = tf.expand_dims(axis_indices, 0)
        return tf.logical_and(
            tf.greater_equal(axis_indices, starts),
            tf.less(axis_indices, ends),
        )

    def corners_to_mask(self, bounding_boxes, mask_shape):
        """Rasterize (x0, y0, x1, y1) boxes into per-sample boolean masks.

        `mask_shape` is `(mask_width, mask_height)`. The result combines a
        per-row and a per-column mask, giving one 2-D mask per box with the
        height axis before the width axis.
        """
        mask_width, mask_height = mask_shape
        x0, y0, x1, y1 = tf.split(bounding_boxes, [1, 1, 1, 1], axis=-1)
        w_mask = self._axis_mask(x0, x1, mask_width)
        h_mask = self._axis_mask(y0, y1, mask_height)
        w_mask = tf.expand_dims(w_mask, axis=1)  # broadcast over rows
        h_mask = tf.expand_dims(h_mask, axis=2)  # broadcast over columns
        return tf.logical_and(w_mask, h_mask)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment