Created
June 12, 2022 12:14
-
-
Save innat/35ab35329e2ca890a17556384056334b to your computer and use it in GitHub Desktop.
Channel shuffle augmentation implemented with the JAX library
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial

import numpy as np
import tensorflow as tf
from jax import jit
from jax import random
from jax.experimental import jax2tf
from tensorflow.keras import layers
class RandomChannelShuffle(layers.Layer):
    """Randomly shuffle the channel groups of an input image.

    The channel axis is split into `groups` contiguous groups, the order of
    the groups is randomly permuted, and the channels are then written back
    interleaved across groups (the group index varies fastest). With
    `groups == channels` this reduces to a plain random channel permutation.

    The same permutation is applied to every image in a batch.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)` format.

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)` format.

    Args:
        groups: Number of groups to divide the input channels. Default 3.

    Raises:
        ValueError: If the input is not 3D or 4D, or if the number of
            channels is not divisible by `groups`.
    """

    def __init__(self, groups=3, **kwargs):
        super().__init__(**kwargs)
        self.groups = groups

    # `self` is marked static so that `self.groups` is a compile-time
    # constant; the seed is a regular (traced) argument so a new value can
    # be supplied on every call without retracing.
    @partial(jit, static_argnums=0)
    def _jax_channel_shuffling(self, images, seed=None):
        """Shuffle channel groups of `images` with JAX.

        Args:
            images: 3D `(height, width, channels)` or 4D
                `(batch, height, width, channels)` array.
            seed: Optional integer scalar used to build the PRNG key. When
                omitted, a seed is drawn on the host with NumPy; under
                `jit` that draw happens once at trace time, so the
                permutation is then fixed for the compiled function.

        Returns:
            Array with the same shape as `images` and shuffled channels.
        """
        rank = images.ndim
        if rank not in (3, 4):
            raise ValueError(
                "Expected a 3D (unbatched) or 4D (batched) input. "
                f"Received: rank={rank}"
            )

        # Work on a batched view; the extra axis is removed before returning.
        unbatched = rank == 3
        if unbatched:
            images = images[None, ...]

        _, height, width, num_channels = images.shape

        if num_channels % self.groups != 0:
            raise ValueError(
                "The number of input channels should be "
                "divisible by the number of groups. "
                f"Received: channels={num_channels}, groups={self.groups}"
            )

        channels_per_group = num_channels // self.groups

        # (batch, H, W, C) -> (batch, H, W, groups, channels_per_group)
        images = images.reshape(-1, height, width, self.groups, channels_per_group)
        # Bring the group axis to the front so it can be permuted:
        # -> (groups, H, W, channels_per_group, batch)
        images = images.transpose([3, 1, 2, 4, 0])

        if seed is None:
            # Legacy fallback: host-side draw, constant within one trace.
            seed = np.random.randint(50)
        key = random.PRNGKey(seed)
        images = random.permutation(key=key, x=images, axis=0)

        # -> (batch, H, W, channels_per_group, groups), then flatten the last
        # two axes back into a single channel axis.
        images = images.transpose([4, 1, 2, 3, 0])
        images = images.reshape(-1, height, width, num_channels)

        if unbatched:
            images = images[0]
        return images

    def call(self, images, training=True):
        """Apply the augmentation when `training` is truthy.

        Args:
            images: 3D or 4D image tensor, channels last.
            training: When falsy, `images` is returned unchanged.

        Returns:
            Tensor with the same shape as `images`.
        """
        if not training:
            return images

        # Draw the seed with a TensorFlow op so that a new value is produced
        # on every execution, including inside a traced `tf.function`.
        seed = tf.random.uniform(
            shape=(), minval=0, maxval=tf.int32.max, dtype=tf.int32
        )
        shuffle = jax2tf.convert(
            self._jax_channel_shuffling,
            # Only the leading dimension of `images` is polymorphic; the
            # scalar seed has a fully static shape.
            polymorphic_shapes=["batch, ...", None],
        )
        return shuffle(images, seed)

    def get_config(self):
        """Return the layer configuration, including `groups`."""
        config = super().get_config()
        config.update({"groups": self.groups})
        return config
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment