Created
May 29, 2023 15:21
-
-
Save innat/03e5bc8253053bf05d03c05c4786894c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow import keras | |
class RGBShift(keras.layers.Layer):
    """Randomly shift the first three channels of an image, each by its own offset.

    One offset per channel is sampled per image (shared by every pixel of
    that image), added to the channel, and the result is clipped to
    ``[0, 255]`` before being cast back to the input dtype.

    Args:
        factor: Shift limits. Either a scalar ``f`` (interpreted as the range
            ``[-|f|, |f|]``) or a tuple/list of two bounds. Both bounds must
            be ``int`` (within ``[-255, 255]``) or both ``float`` (within
            ``[-1.0, 1.0]``). Offsets sampled from a float range are
            multiplied by ``_FLOAT_FACTOR_SCALE`` before being applied.
        seed: Optional integer seed. When given, offsets come from
            ``tf.random.stateless_uniform`` and are therefore identical on
            every call for a given input shape.
        **kwargs: Forwarded to ``keras.layers.Layer``.

    Raises:
        ValueError: If ``factor`` has the wrong type, length, mixed bound
            types, or is out of range.
    """

    # Multiplier applied to offsets sampled from a float ``factor`` range.
    # NOTE(review): presumably maps fractional limits onto the 0-255 pixel
    # scale; confirm that 85 (rather than 255) is the intended value.
    _FLOAT_FACTOR_SCALE = 85.0

    def __init__(
        self,
        factor,
        seed=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.factor = self._set_shift_limit(factor)
        self.seed = seed

    def _set_shift_limit(self, factor):
        """Normalize ``factor`` into a validated ``[lower, upper]`` list."""
        if isinstance(factor, (tuple, list)):
            if len(factor) != 2:
                raise ValueError(
                    'The factor should be a scalar, or a tuple or list of '
                    f'two lower and upper bound numbers. Got {factor}'
                )
            return self._check_factor_range(sorted(factor))
        if isinstance(factor, (int, float)):
            limit = abs(factor)
            return self._check_factor_range([-limit, limit])
        raise ValueError(
            'The factor should be a scalar, or a tuple or list of '
            f'two lower and upper bound numbers. Got {factor}'
        )

    @staticmethod
    def _check_factor_range(factor):
        """Return ``factor`` unchanged if both bounds share a type and are in range.

        Float bounds must lie in ``[-1.0, 1.0]``; int bounds in ``[-255, 255]``.
        """
        if all(isinstance(bound, float) for bound in factor):
            if factor[0] < -1.0 or factor[1] > 1.0:
                raise ValueError(
                    f'Float factor bounds must be within [-1.0, 1.0]. Got {factor}'
                )
            return factor
        if all(isinstance(bound, int) for bound in factor):
            if factor[0] < -255 or factor[1] > 255:
                raise ValueError(
                    f'Integer factor bounds must be within [-255, 255]. Got {factor}'
                )
            return factor
        raise ValueError(f'Both bound must be same dtype. Got {factor}')

    def _get_random_uniform(self, shift_limit, rgb_delta_shape, channel=0):
        """Sample float32 offsets of shape ``rgb_delta_shape`` within ``shift_limit``.

        Args:
            shift_limit: ``[lower, upper]`` sampling bounds.
            rgb_delta_shape: Shape of the offset tensor to sample.
            channel: Index of the channel the offsets are for. It is folded
                into the stateless seed so that each channel draws from a
                different stream; without it, a seeded layer would shift all
                channels by the same amount.

        Returns:
            A float32 tensor of offsets, scaled by ``_FLOAT_FACTOR_SCALE``
            when ``shift_limit`` holds floats.
        """
        if self.seed is not None:
            offsets = tf.random.stateless_uniform(
                shape=rgb_delta_shape,
                seed=[channel, self.seed],
                minval=shift_limit[0],
                maxval=shift_limit[1],
                dtype=tf.float32,
            )
        else:
            offsets = tf.random.uniform(
                rgb_delta_shape,
                minval=shift_limit[0],
                maxval=shift_limit[1],
                dtype=tf.float32,
            )

        if all(isinstance(bound, float) for bound in shift_limit):
            offsets = offsets * self._FLOAT_FACTOR_SCALE
        return offsets

    def _rgb_shifting(self, images):
        """Add an independent random offset to each of the first three channels.

        Args:
            images: A rank-3 (single image) or rank-4 (batch) tensor whose
                last axis holds the channels.

        Returns:
            A tensor of the input dtype with values clipped to ``[0, 255]``.
            Only the first three channels are kept.

        Raises:
            ValueError: If ``images`` is not rank 3 or 4.
        """
        rank = images.shape.rank
        original_dtype = images.dtype

        if rank == 3:
            rgb_delta_shape = (1, 1)
        elif rank == 4:
            # Keep only the batch dim so the adjustment is the same within
            # one image but differs across the images of the batch.
            rgb_delta_shape = [tf.shape(images)[0], 1, 1]
        else:
            raise ValueError(
                f"Expect the input image to be rank 3 or 4. Got {images.shape}"
            )

        channels = tf.unstack(tf.cast(images, dtype=tf.float32), axis=-1)
        shifted_channels = [
            tf.add(
                channel_values,
                self._get_random_uniform(
                    self.factor, rgb_delta_shape, channel=index
                ),
            )
            for index, channel_values in enumerate(channels[:3])
        ]
        shifted_rgb = tf.stack(shifted_channels, axis=-1)
        shifted_rgb = tf.clip_by_value(shifted_rgb, 0.0, 255.0)
        return tf.cast(shifted_rgb, dtype=original_dtype)

    def call(self, images, training=True):
        """Apply the random shift to ``images``.

        NOTE(review): ``training`` is accepted but not consulted, so the shift
        is applied regardless of its value; confirm whether inference should
        pass inputs through unchanged.
        """
        return self._rgb_shifting(images)

    def get_config(self):
        """Return the layer config, including the normalized ``factor`` and ``seed``."""
        config = super().get_config()
        config.update(
            {
                "factor": self.factor,
                "seed": self.seed,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape
# Example usage: run the layer on a dummy batch of five 224x224 three-channel
# images, drawing each channel's shift from the integer range [-120, 120].
images = tf.ones((5, 224, 224, 3))
rgbshift_images = RGBShift((-120, 120))(images)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment