Created
January 23, 2024 10:02
-
-
Save innat/b6ede34630e4a2988c968467f6d3facb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras import layers | |
# Negative indices of the height and width axes of a channels-last image
# tensor laid out as (..., height, width, channels).
H_AXIS = -3
W_AXIS = -2
class RandomCutout(layers.Layer):
    """Randomly cut out one rectangle per image and fill it.

    The layer expects a batched, channels-last tensor of shape
    `(batch, height, width, channels)`. For every image a rectangle centre is
    sampled uniformly inside the image; the part of the rectangle that falls
    inside the image is replaced by the fill.

    Args:
        height_factor: A single float. Controls the height of the cutout
            relative to the image height. `height_factor=0.0` means the
            rectangle will be of size 0% of the image height,
            `height_factor=0.1` means the rectangle will have a size of 10%
            of the image height, and so forth.
        width_factor: A single float. Controls the width of the cutout
            relative to the image width. `width_factor=0.0` means the
            rectangle will be of size 0% of the image width,
            `width_factor=0.1` means the rectangle will have a size of 10%
            of the image width, and so forth.
        fill_mode: Pixels inside the patches are filled according to the
            given mode (one of `{"constant", "gaussian_noise"}`).
            - *constant*: Pixels are filled with the same constant value.
            - *gaussian_noise*: Pixels are filled with random gaussian noise
              rescaled to the value range of the input batch.
        fill_value: A float, the value to be filled inside the patches when
            `fill_mode="constant"`.
        seed: Integer or `None`. When set, it is used to derive the op-level
            seeds of the random ops used by this layer.

    Raises:
        ValueError: If `fill_mode` is not `"constant"` or `"gaussian_noise"`.

    Sample usage:
    ```python
    (images, labels), _ = load_data()
    random_cutout = RandomCutout(0.5, 0.5)
    augmented_images = random_cutout(images)
    ```

    # Disclaimer
    Original implementation: https://github.com/keras-team/keras-cv.
    The original implementation provides a richer interface to apply this
    augmentation on various CV related tasks, i.e. object detection etc. It
    also provides many effective validation checks.
    Derived and modified for simpler usages: M.Innat.
    Ref. https://gist.github.com/innat/b6ede34630e4a2988c968467f6d3facb
    """

    def __init__(
        self,
        height_factor,
        width_factor,
        fill_mode="constant",
        fill_value=0.0,
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if fill_mode not in ("gaussian_noise", "constant"):
            raise ValueError(
                '`fill_mode` should be "gaussian_noise" '
                f'or "constant". Got `fill_mode`={fill_mode}'
            )
        self.height_factor = height_factor
        # Stored as a plain scalar (no trailing comma) so that `get_config`
        # round-trips and the size computation sees a float, not a 1-tuple.
        self.width_factor = width_factor
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.seed = seed

    def _op_seed(self, offset):
        """Return an op-level seed derived from `self.seed`, or `None`.

        Each random op gets a different `offset` so that ops sharing the
        layer seed do not draw from the same random stream.
        """
        if self.seed is None:
            return None
        return self.seed + offset

    def get_random_transformation_batch(self, images, **kwargs):
        """Sample the cutout rectangle (centre and size) for every image.

        Args:
            images: Batched image tensor.
            **kwargs: Accepted for call-signature compatibility; unused.

        Returns:
            A dict with the 1-D int32 tensors `centers_x`, `centers_y`,
            `rectangles_height` and `rectangles_width`, one entry per image.
        """
        centers_x, centers_y = self._compute_rectangle_position(images)
        rectangles_height, rectangles_width = self._compute_rectangle_size(
            images
        )
        return {
            "centers_x": centers_x,
            "centers_y": centers_y,
            "rectangles_height": rectangles_height,
            "rectangles_width": rectangles_width,
        }

    def fill_rectangle(
        self, images, centers_x, centers_y, widths, heights, fill_values
    ):
        """Fill rectangles with fill value into images.

        Args:
            images: Tensor of images to fill rectangles into
            centers_x: Tensor of positions of the rectangle centers on the
                x-axis
            centers_y: Tensor of positions of the rectangle centers on the
                y-axis
            widths: Tensor of widths of the rectangles
            heights: Tensor of heights of the rectangles
            fill_values: Tensor with same shape as images to get rectangle
                fill from

        Returns:
            images with filled rectangles.
        """
        images_shape = tf.shape(images)
        images_height = images_shape[H_AXIS]
        images_width = images_shape[W_AXIS]

        # One (center_x, center_y, width, height) row per image.
        xywh = tf.stack([centers_x, centers_y, widths, heights], axis=1)
        xywh = tf.cast(xywh, tf.float32)
        corners = self.convert_format(xywh)

        # `corners_to_mask` expects the shape as (width, height).
        mask_shape = (images_width, images_height)
        is_rectangle = self.corners_to_mask(corners, mask_shape)
        # Add a trailing axis so the mask broadcasts over the channels.
        is_rectangle = tf.expand_dims(is_rectangle, -1)
        return tf.where(is_rectangle, fill_values, images)

    def convert_format(self, boxes):
        """Convert centre-format boxes to corner format.

        Args:
            boxes: Tensor whose last axis holds `(x, y, width, height)`
                followed by any number of extra columns.

        Returns:
            A float32 tensor whose last axis holds `(x0, y0, x1, y1)`
            followed by the untouched extra columns.
        """
        boxes = tf.cast(boxes, dtype=tf.float32)
        x, y, width, height, rest = tf.split(boxes, [1, 1, 1, 1, -1], axis=-1)
        half_width = width / 2.0
        half_height = height / 2.0
        return tf.concat(
            [
                x - half_width,
                y - half_height,
                x + half_width,
                y + half_height,
                rest,
            ],
            axis=-1,
        )

    def call(self, images, **kwargs):
        """Apply random cutout to a batch of images.

        Args:
            images: Batched image tensor of shape
                `(batch, height, width, channels)`.
            **kwargs: Forwarded to `get_random_transformation_batch`.

        Returns:
            The images, cast to the layer's compute dtype, with one filled
            rectangle each.
        """
        # The fill tensor is created in the compute dtype and `tf.where`
        # needs both branches to share a dtype, so bring the images to the
        # compute dtype as well (matters for e.g. integer inputs).
        images = tf.cast(images, self.compute_dtype)
        transformations = self.get_random_transformation_batch(
            images, **kwargs
        )
        rectangles_fill = self._compute_rectangle_fill(images)
        return self.fill_rectangle(
            images,
            transformations["centers_x"],
            transformations["centers_y"],
            transformations["rectangles_width"],
            transformations["rectangles_height"],
            rectangles_fill,
        )

    def _get_image_shape(self, images):
        """Return per-image heights and widths as 1-D int32 tensors."""
        shape = tf.shape(images)
        batch_size = shape[0]
        # Every image in a dense batch shares the same height and width.
        heights = tf.fill([batch_size], shape[H_AXIS])
        widths = tf.fill([batch_size], shape[W_AXIS])
        return tf.cast(heights, dtype=tf.int32), tf.cast(widths, dtype=tf.int32)

    def _compute_rectangle_position(self, inputs):
        """Sample one rectangle centre per image, uniformly inside the image.

        Returns:
            `(center_x, center_y)` as 1-D int32 tensors.
        """
        batch_size = tf.shape(inputs)[0]
        heights, widths = self._get_image_shape(inputs)
        # Generate values in float32 and then cast (i.e. truncate) to int32
        # because `maxval` is passed as a 1-D tensor, which the integer
        # variant of `tf.random.uniform` does not broadcast.
        heights = tf.cast(heights, dtype=tf.float32)
        widths = tf.cast(widths, dtype=tf.float32)
        center_x = tf.random.uniform(
            (batch_size,), 0, widths, dtype=tf.float32, seed=self._op_seed(0)
        )
        center_y = tf.random.uniform(
            (batch_size,), 0, heights, dtype=tf.float32, seed=self._op_seed(1)
        )
        center_x = tf.cast(center_x, tf.int32)
        center_y = tf.cast(center_y, tf.int32)
        return center_x, center_y

    def _compute_rectangle_size(self, inputs):
        """Compute the rectangle size in pixels for every image.

        Returns:
            `(height, width)` as 1-D int32 tensors, each clamped to the
            corresponding image dimension.
        """
        images_heights, images_widths = self._get_image_shape(inputs)
        height = self.height_factor * tf.cast(images_heights, tf.float32)
        width = self.width_factor * tf.cast(images_widths, tf.float32)
        height = tf.cast(tf.math.ceil(height), tf.int32)
        width = tf.cast(tf.math.ceil(width), tf.int32)
        # Clamp each side against its own image dimension.
        height = tf.minimum(height, images_heights)
        width = tf.minimum(width, images_widths)
        return height, width

    def _compute_rectangle_fill(self, inputs):
        """Build a tensor shaped like `inputs` holding the fill content.

        Returns:
            A tensor in the layer's compute dtype: the constant `fill_value`
            everywhere, or gaussian noise rescaled to the min/max of
            `inputs`.
        """
        input_shape = tf.shape(inputs)
        if self.fill_mode == "constant":
            fill_value = tf.fill(input_shape, self.fill_value)
            return tf.cast(fill_value, dtype=self.compute_dtype)

        # gaussian noise
        inputs = tf.cast(inputs, self.compute_dtype)
        fill_value = tf.random.normal(
            input_shape, dtype=self.compute_dtype, seed=self._op_seed(2)
        )
        # rescale the random noise to the original image range
        image_max = tf.reduce_max(inputs)
        image_min = tf.reduce_min(inputs)
        fill_max = tf.reduce_max(fill_value)
        fill_min = tf.reduce_min(fill_value)
        fill_value = (image_max - image_min) * (fill_value - fill_min) / (
            fill_max - fill_min
        ) + image_min
        return fill_value

    def _axis_mask(self, starts, ends, mask_len):
        """Return a boolean mask of the indices `i` with `starts <= i < ends`.

        Args:
            starts: Tensor of shape `(batch, 1)` with inclusive lower bounds.
            ends: Tensor of shape `(batch, 1)` with exclusive upper bounds.
            mask_len: Length of the axis being masked.

        Returns:
            A boolean tensor of shape `(batch, mask_len)`.
        """
        # index range of axis, shaped (1, mask_len) so it broadcasts against
        # the per-image bounds
        axis_indices = tf.range(mask_len, dtype=starts.dtype)
        axis_indices = tf.expand_dims(axis_indices, 0)
        # mask of index bounds
        return tf.greater_equal(axis_indices, starts) & tf.less(
            axis_indices, ends
        )

    def corners_to_mask(self, bounding_boxes, mask_shape):
        """Rasterize corner-format boxes into boolean masks.

        Args:
            bounding_boxes: Tensor of shape `(batch, 4)` holding
                `(x0, y0, x1, y1)` per image.
            mask_shape: Tuple `(mask_width, mask_height)`.

        Returns:
            A boolean tensor of shape `(batch, mask_height, mask_width)`
            that is `True` inside each box.
        """
        mask_width, mask_height = mask_shape
        x0, y0, x1, y1 = tf.split(bounding_boxes, [1, 1, 1, 1], axis=-1)
        w_mask = self._axis_mask(x0, x1, mask_width)
        h_mask = self._axis_mask(y0, y1, mask_height)
        # (batch, 1, width) & (batch, height, 1) -> (batch, height, width)
        w_mask = tf.expand_dims(w_mask, axis=1)
        h_mask = tf.expand_dims(h_mask, axis=2)
        return tf.logical_and(w_mask, h_mask)

    def get_config(self):
        """Return the layer configuration, including the constructor args."""
        config = super().get_config()
        config.update(
            {
                "height_factor": self.height_factor,
                "width_factor": self.width_factor,
                "fill_mode": self.fill_mode,
                "fill_value": self.fill_value,
                "seed": self.seed,
            }
        )
        return config
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment