Skip to content

Instantly share code, notes, and snippets.

@innat
Created June 12, 2022 12:14
Show Gist options
  • Save innat/35ab35329e2ca890a17556384056334b to your computer and use it in GitHub Desktop.
Save innat/35ab35329e2ca890a17556384056334b to your computer and use it in GitHub Desktop.
Channel shuffle augmentation implemented with the JAX library
from functools import partial

import numpy as np
import tensorflow as tf
from jax import jit
from jax import random
from jax.experimental import jax2tf
from tensorflow.keras import layers
class RandomChannelShuffle(layers.Layer):
    """Randomly shuffle the channel groups of an input image.

    The channel axis is split into `groups` equally sized groups, the groups
    are randomly permuted, and the result is flattened back so that the
    permuted groups end up interleaved along the channel axis. When `groups`
    equals the number of channels this is a plain random permutation of the
    channels. A single permutation is drawn per call and shared by every
    image in the batch.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)` format.
    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)` format.
    Args:
        groups: Number of groups to divide the input channels. Default 3.
    """

    def __init__(self, groups=3, **kwargs):
        super().__init__(**kwargs)
        self.groups = groups

    @partial(jit, static_argnums=0)
    def _jax_channel_shuffling(self, images, seed=0):
        """Shuffle the channel groups of a 4D batch with JAX.

        Args:
            images: Array of shape `(batch, height, width, channels)`.
            seed: Integer scalar used to build the PRNG key. It is a traced
                argument (not a Python constant), so a different value per
                call yields a different permutation without re-compiling.

        Returns:
            Array with the same shape as `images`.

        Raises:
            ValueError: If the number of channels is not divisible by
                `self.groups` (checked at trace time).
        """
        _, height, width, num_channels = images.shape

        if num_channels % self.groups != 0:
            raise ValueError(
                "The number of input channels should be "
                "divisible by the number of groups. "
                f"Received: channels={num_channels}, groups={self.groups}"
            )

        channels_per_group = num_channels // self.groups

        # (batch, H, W, C) -> (batch, H, W, groups, channels_per_group)
        images = images.reshape(
            -1, height, width, self.groups, channels_per_group
        )
        # Move the group axis to the front so it can be permuted along axis 0:
        # -> (groups, H, W, channels_per_group, batch)
        images = images.transpose([3, 1, 2, 4, 0])

        # The key must be derived from a runtime argument. Drawing the seed
        # with NumPy here would run only once, at trace time, and freeze the
        # permutation into the compiled function.
        key = random.PRNGKey(seed)
        images = random.permutation(key=key, x=images, axis=0)

        # -> (batch, H, W, channels_per_group, groups); flattening the last
        # two axes interleaves the permuted groups along the channel axis.
        images = images.transpose([4, 1, 2, 3, 0])
        images = images.reshape(-1, height, width, num_channels)
        return images

    def call(self, images, training=True):
        """Apply the shuffle when training, otherwise return `images` as is.

        Args:
            images: 3D (unbatched) or 4D (batched) image tensor in
                channels-last format.
            training: Whether the layer runs in training mode. The input is
                returned unchanged when this is falsy.

        Returns:
            Tensor with the same shape as `images`.
        """
        if not training:
            return images

        # The JAX routine expects a batch axis; add a temporary one for
        # unbatched (height, width, channels) inputs.
        unbatched = len(images.shape) == 3
        if unbatched:
            images = tf.expand_dims(images, axis=0)

        # Draw the seed with a TF op so that a new value is produced on every
        # execution, including inside a compiled `tf.function` graph.
        seed = tf.random.uniform(shape=(), maxval=tf.int32.max, dtype=tf.int32)

        shuffled = jax2tf.convert(
            self._jax_channel_shuffling,
            # Only the batch dimension of `images` is polymorphic; the scalar
            # seed has a fully static shape.
            polymorphic_shapes=["batch, ...", None],
        )(images, seed)

        if unbatched:
            shuffled = tf.squeeze(shuffled, axis=0)
        return shuffled

    def get_config(self):
        """Return the layer configuration, including `groups`."""
        config = super().get_config()
        config.update({"groups": self.groups})
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment