Created
June 12, 2022 12:02
-
-
Save innat/0524ee77de17f0601f0dee69aa52c713 to your computer and use it in GitHub Desktop.
Vectorized Implementation of CutMix Augmentation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras import layers | |
class CutMix(layers.Layer):
    """Vectorized CutMix augmentation for a batch of images and labels.

    For every image in the batch, an axis-aligned box is replaced by the same
    region of another image drawn from a random permutation of the batch.
    Labels are mixed in proportion to the fraction of pixels that were
    actually kept from the original image.

    Original implementation: https://github.com/keras-team/keras-cv.
    The original implementation provides more interfaces to apply CutMix to
    various CV-related tasks (e.g. object detection) and also provides many
    useful validation checks. Derived and modified for simpler usage by
    M.Innat.

    Args:
        alpha: Concentration of the symmetric Beta(alpha, alpha) distribution
            from which the mixing ratio is sampled. Must be > 0.
        seed: Optional op-level seed. NOTE: only the batch permutation uses
            it; the Beta and box-centre draws are not seeded.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Call args:
        batch_inputs: A pair ``(images, labels)``. Images are indexed as
            ``(batch, height, width, channels)``. Labels are assumed to be
            dense (e.g. one-hot) with shape ``(batch, num_classes)``, since
            the per-sample weights are reshaped to ``(batch, 1)``.
        training: Accepted for Keras compatibility but ignored; the
            augmentation is always applied.

    Returns:
        ``[cutmix_images, cutmix_labels]``, both cast to ``tf.float32``.
    """

    def __init__(self, alpha=1.0, seed=None, **kwargs):
        super().__init__(**kwargs)
        # A non-positive concentration makes the Gamma draws (and thus the
        # mixing ratio) NaN; fail early for plain Python numbers.
        if isinstance(alpha, (int, float)) and alpha <= 0:
            raise ValueError(f"`alpha` must be strictly positive, got {alpha}.")
        self.alpha = alpha
        self.seed = seed

    def get_config(self):
        """Return the layer config so Keras can re-create this layer."""
        config = super().get_config()
        config.update({"alpha": self.alpha, "seed": self.seed})
        return config

    @staticmethod
    def _sample_from_beta(alpha, beta, shape):
        """Draw samples of ``shape`` from Beta(alpha, beta).

        Uses the identity X / (X + Y) ~ Beta(alpha, beta) for
        X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1). The concentration must be
        passed as the Gamma *shape* parameter (``alpha=`` of
        ``tf.random.gamma``). Passing it as the rate with a fixed shape of 1.0
        yields exponential draws, and the ratio is then Uniform(0, 1) for any
        symmetric alpha.
        """
        sample_alpha = tf.random.gamma(shape, alpha=alpha)
        sample_beta = tf.random.gamma(shape, alpha=beta)
        # For very small concentrations both draws can underflow to 0 in
        # float32; return 0 instead of NaN in that case.
        return tf.math.divide_no_nan(sample_alpha, sample_alpha + sample_beta)

    def _cutmix_labels(self, labels, lambda_sample, permutation_order):
        """Mix each label with the label of its permuted partner.

        Args:
            labels: Float label tensor whose first axis is the batch.
            lambda_sample: Per-sample fraction of pixels kept from the
                original image, shape ``(batch,)``.
            permutation_order: Partner index for every sample.

        Returns:
            ``lambda * labels + (1 - lambda) * partner_labels``.
        """
        partner_labels = tf.gather(labels, permutation_order)
        # (batch,) -> (batch, 1) so the weight broadcasts over the class axis.
        lambda_sample = tf.reshape(lambda_sample, [-1, 1])
        return lambda_sample * labels + (1.0 - lambda_sample) * partner_labels

    def _cutmix_samples(self, images):
        """Paste a random box from a permuted partner into every image.

        Returns:
            A tuple ``(images, lambda_sample, permutation_order)``.
            ``lambda_sample`` has shape ``(batch,)`` and holds the fraction of
            pixels of each output image that still come from the original.
        """
        input_shape = tf.shape(images)
        batch_size = input_shape[0]
        image_height = input_shape[1]
        image_width = input_shape[2]

        permutation_order = tf.random.shuffle(tf.range(0, batch_size), seed=self.seed)

        lambda_sample = CutMix._sample_from_beta(self.alpha, self.alpha, (batch_size,))
        # Side lengths scale with sqrt(1 - lambda), so the unclipped box covers
        # a (1 - lambda) fraction of the image area.
        ratio = tf.math.sqrt(1.0 - lambda_sample)
        cut_height = tf.cast(
            ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32
        )
        # The box width must scale with the image *width* (previously height).
        cut_width = tf.cast(
            ratio * tf.cast(image_width, dtype=tf.float32), dtype=tf.int32
        )

        random_center_height = tf.random.uniform(
            shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32
        )
        random_center_width = tf.random.uniform(
            shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32
        )

        is_rectangle = self._rectangle_mask(
            images, random_center_width, random_center_height, cut_width, cut_height
        )
        images = tf.where(
            tf.expand_dims(is_rectangle, -1),
            tf.gather(images, permutation_order),
            images,
        )

        # Boxes centred near a border are clipped to the image by the mask, so
        # cut_height * cut_width over-counts the pasted area. Derive lambda from
        # the pixels that were actually replaced to keep labels consistent.
        pasted_fraction = tf.reduce_mean(
            tf.cast(is_rectangle, dtype=tf.float32), axis=[1, 2]
        )
        lambda_sample = 1.0 - pasted_fraction
        return images, lambda_sample, permutation_order

    def call(self, batch_inputs, training=None):
        """Apply CutMix to ``(images, labels)`` and return the mixed pair."""
        bs_images = tf.cast(batch_inputs[0], dtype=tf.float32)  # all image samples
        bs_labels = tf.cast(batch_inputs[1], dtype=tf.float32)  # all label samples
        cutmix_images, lambda_sample, permutation_order = self._cutmix_samples(
            bs_images
        )
        cutmix_labels = self._cutmix_labels(bs_labels, lambda_sample, permutation_order)
        return [cutmix_images, cutmix_labels]

    def _rectangle_mask(self, images, centers_x, centers_y, widths, heights):
        """Return a boolean ``(batch, height, width)`` mask, True inside each box.

        Boxes are given by centre and size in pixels; parts falling outside
        the image are implicitly clipped because only in-range indices exist.
        """
        images_shape = tf.shape(images)
        images_height = images_shape[1]
        images_width = images_shape[2]
        xywh = tf.stack([centers_x, centers_y, widths, heights], axis=1)
        corners = self.convert_format(tf.cast(xywh, tf.float32))
        return self.corners_to_mask(corners, (images_width, images_height))

    def fill_rectangle(
        self, images, centers_x, centers_y, widths, heights, fill_values
    ):
        """Replace an axis-aligned box in each image by ``fill_values``.

        Args:
            images: Batch of images indexed ``(batch, height, width, ...)``.
            centers_x, centers_y: Per-sample box centre in pixels.
            widths, heights: Per-sample box size in pixels.
            fill_values: Tensor broadcastable to ``images`` supplying the
                pixels written inside each box.

        Returns:
            ``images`` with the box regions taken from ``fill_values``.
        """
        is_rectangle = self._rectangle_mask(images, centers_x, centers_y, widths, heights)
        # Add a trailing axis so the mask broadcasts over channels.
        is_rectangle = tf.expand_dims(is_rectangle, -1)
        return tf.where(is_rectangle, fill_values, images)

    def convert_format(self, boxes):
        """Convert ``[cx, cy, w, h, *rest]`` boxes to ``[x0, y0, x1, y1, *rest]``.

        Any trailing columns after the first four are passed through unchanged.
        """
        boxes = tf.cast(boxes, dtype=tf.float32)
        x, y, width, height, rest = tf.split(boxes, [1, 1, 1, 1, -1], axis=-1)
        return tf.concat(
            [
                x - width / 2.0,
                y - height / 2.0,
                x + width / 2.0,
                y + height / 2.0,
                rest,
            ],
            axis=-1,
        )

    def _axis_mask(self, starts, ends, mask_len):
        """Return a ``(batch, mask_len)`` mask of indices in ``[starts, ends)``.

        ``starts`` and ``ends`` have shape ``(batch, 1)``; broadcasting them
        against a ``(1, mask_len)`` index row yields one row per sample.
        """
        axis_indices = tf.range(mask_len, dtype=starts.dtype)[tf.newaxis, :]
        return tf.greater_equal(axis_indices, starts) & tf.less(axis_indices, ends)

    def corners_to_mask(self, bounding_boxes, mask_shape):
        """Rasterize ``[x0, y0, x1, y1]`` boxes into a boolean mask.

        Args:
            bounding_boxes: ``(batch, 4)`` corner coordinates.
            mask_shape: ``(mask_width, mask_height)`` -- note width first.

        Returns:
            Boolean tensor of shape ``(batch, mask_height, mask_width)``.
        """
        mask_width, mask_height = mask_shape
        x0, y0, x1, y1 = tf.split(bounding_boxes, [1, 1, 1, 1], axis=-1)
        w_mask = self._axis_mask(x0, x1, mask_width)
        h_mask = self._axis_mask(y0, y1, mask_height)
        # (batch, 1, W) & (batch, H, 1) -> (batch, H, W)
        w_mask = tf.expand_dims(w_mask, axis=1)
        h_mask = tf.expand_dims(h_mask, axis=2)
        return tf.logical_and(w_mask, h_mask)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment