Skip to content

Instantly share code, notes, and snippets.

@innat
Created June 12, 2022 12:01
Show Gist options
  • Save innat/0ee2b6155d663aac2617fe596e1d8d49 to your computer and use it in GitHub Desktop.
Save innat/0ee2b6155d663aac2617fe596e1d8d49 to your computer and use it in GitHub Desktop.
Vectorized Implementation of MixUp Augmentation.
import tensorflow as tf
from tensorflow.keras import layers
class MixUp(layers.Layer):
    """Batch-level MixUp augmentation.

    Every sample in a batch is blended with a partner sample from the same
    batch (chosen by a random permutation) using a per-sample weight
    ``lambda ~ Beta(alpha, alpha)``. The labels are blended with the same
    weights, so ``out = lambda * x + (1 - lambda) * x[permutation]``.

    Original implementation: https://github.com/keras-team/keras-cv.
    The original implementation provides more interfaces for applying MixUp
    to other CV tasks (e.g. object detection) and performs many useful
    validation checks. Derived and modified for simpler usage: M.Innat.

    Args:
        alpha: Concentration parameter of the symmetric Beta distribution the
            mixing weights are drawn from. Must be positive. Small values
            put most weights near 0 or 1 (mild mixing); large values
            concentrate them around 0.5.
        seed: Optional op-level seed, used only for the batch permutation.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, alpha=0.2, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.seed = seed

    @staticmethod
    def _sample_from_beta(alpha, beta, shape):
        """Draw samples of the given ``shape`` from Beta(alpha, beta).

        This uses the Gamma-ratio construction: if X ~ Gamma(alpha, 1) and
        Y ~ Gamma(beta, 1) are independent, then X / (X + Y) ~ Beta(alpha, beta).
        """
        # alpha/beta must be the Gamma *concentration* (shape) parameter.
        # Passing them as the rate (``beta=``) with concentration 1.0 draws
        # exponentials instead. In the symmetric case used by this layer, the
        # ratio of those is Uniform(0, 1), so ``alpha`` would have no effect.
        sample_alpha = tf.random.gamma(shape, alpha=alpha)
        sample_beta = tf.random.gamma(shape, alpha=beta)
        return sample_alpha / (sample_alpha + sample_beta)

    @staticmethod
    def _broadcastable(lambda_sample, target, default_rank):
        """Reshape per-sample weights so they broadcast against ``target``.

        The result has shape ``(batch, 1, ..., 1)`` with the same rank as
        ``target``. If the static rank of ``target`` is unknown, it falls back
        to ``default_rank`` (the layout the layer historically assumed).
        """
        rank = target.shape.rank
        if rank is None:
            rank = default_rank
        return tf.reshape(lambda_sample, [-1] + [1] * (rank - 1))

    def _mixup_samples(self, images):
        """Blend each image with a randomly permuted partner from its batch.

        Returns:
            A tuple ``(mixed_images, lambda_sample, permutation_order)``.
            ``lambda_sample`` is the 1-D vector of per-sample weights and
            ``permutation_order`` is the partner index for each sample.
        """
        batch_size = tf.shape(images)[0]
        permutation_order = tf.random.shuffle(
            tf.range(0, batch_size), seed=self.seed
        )
        lambda_sample = self._sample_from_beta(
            self.alpha, self.alpha, (batch_size,)
        )
        # Default rank 4 matches the original (batch, H, W, C) assumption.
        weights = self._broadcastable(lambda_sample, images, default_rank=4)
        partner_images = tf.gather(images, permutation_order)
        mixed_images = weights * images + (1.0 - weights) * partner_images
        return mixed_images, lambda_sample, permutation_order

    def _mixup_labels(self, labels, lambda_sample, permutation_order):
        """Blend labels with the same weights and partners used for images."""
        partner_labels = tf.gather(labels, permutation_order)
        # Default rank 2 matches the original (batch, num_classes) assumption.
        weights = self._broadcastable(lambda_sample, labels, default_rank=2)
        return weights * labels + (1.0 - weights) * partner_labels

    def call(self, batch_inputs):
        """Apply MixUp to a batch.

        Args:
            batch_inputs: Indexable pair ``(images, labels)``. Both are cast to
                float32 and must share the leading batch dimension. The labels
                are presumably one-hot or soft targets (TODO confirm with
                callers); integer class ids would be averaged meaninglessly.

        Returns:
            A list ``[mixed_images, mixed_labels]``.
        """
        bs_images = tf.cast(batch_inputs[0], dtype=tf.float32)  # all image samples
        bs_labels = tf.cast(batch_inputs[1], dtype=tf.float32)  # all label samples
        mixup_images, lambda_sample, permutation_order = self._mixup_samples(bs_images)
        mixup_labels = self._mixup_labels(bs_labels, lambda_sample, permutation_order)
        return [mixup_images, mixup_labels]

    def get_config(self):
        """Return the layer config, including ``alpha`` and ``seed``."""
        config = super().get_config()
        config.update(
            {
                "alpha": self.alpha,
                "seed": self.seed,
            }
        )
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment