Skip to content

Instantly share code, notes, and snippets.

@innat
Created May 29, 2023 15:18
Show Gist options
  • Save innat/47241247a5784aa3381b7610c560ccb2 to your computer and use it in GitHub Desktop.
Save innat/47241247a5784aa3381b7610c560ccb2 to your computer and use it in GitHub Desktop.
import tensorflow as tf
from tensorflow import keras
class ChannelShuffle(keras.layers.Layer):
    """Randomly permute groups of channels in a batch of images.

    The channel (last) axis is split into ``groups`` contiguous groups of
    ``channels // groups`` channels each. In training mode the order of the
    groups is shuffled with ``tf.random.shuffle``; a single permutation is
    drawn per call and shared by every image in the batch. The shuffled groups
    are then interleaved: output channel ``c * groups + g`` holds channel ``c``
    of the group placed at position ``g`` by the shuffle.

    Args:
        groups: Number of channel groups. Must be a positive integer that
            divides the number of input channels.
        seed: Optional op-level seed forwarded to ``tf.random.shuffle``.
        **kwargs: Forwarded to ``keras.layers.Layer``.

    Raises:
        ValueError: If ``groups`` is not positive, or (at call time) if the
            input is not rank 4, its channel dimension is not statically
            known, or the channel count is not divisible by ``groups``.
    """

    def __init__(self, groups=3, seed=None, **kwargs):
        super().__init__(**kwargs)
        if groups < 1:
            raise ValueError(
                f"`groups` must be a positive integer. Received: groups={groups}"
            )
        self.groups = groups
        self.seed = seed

    def _channel_shuffling(self, images):
        """Shuffle channel groups of a rank-4, channels-last ``images`` tensor."""
        rank = images.shape.rank
        if rank is not None and rank != 4:
            raise ValueError(
                "ChannelShuffle expects a rank-4 input "
                f"(batch, height, width, channels). Received rank: {rank}"
            )
        # The grouping is built from the static channel count.
        num_channels = images.shape[3]
        if num_channels is None:
            raise ValueError(
                "ChannelShuffle requires a statically known channel dimension."
            )
        # Without this check the reshape below can fail opaquely or
        # mix pixels from different images.
        if num_channels % self.groups != 0:
            raise ValueError(
                "The number of input channels must be divisible by `groups`. "
                f"Received: channels={num_channels}, groups={self.groups}"
            )
        channels_per_group = num_channels // self.groups

        dynamic_shape = tf.shape(images)
        batch_size = dynamic_shape[0]
        height = dynamic_shape[1]
        width = dynamic_shape[2]

        # (B, H, W, C) -> (B, H, W, groups, C // groups)
        images = tf.reshape(
            images, [batch_size, height, width, self.groups, channels_per_group]
        )
        # tf.random.shuffle permutes only axis 0, so move the group axis there.
        # Batch goes to the last axis, so all images share one permutation.
        images = tf.transpose(images, perm=[3, 1, 2, 4, 0])
        images = tf.random.shuffle(images, seed=self.seed)
        # (groups, H, W, C // groups, B) -> (B, H, W, C // groups, groups):
        # batch back in front, groups become the fastest-varying channel index.
        images = tf.transpose(images, perm=[4, 1, 2, 3, 0])
        return tf.reshape(images, [batch_size, height, width, num_channels])

    def call(self, images, training=True):
        """Shuffle channel groups if ``training`` is truthy, else pass through."""
        if training:
            return self._channel_shuffling(images)
        return images

    def get_config(self):
        """Return the layer config, including ``groups`` and ``seed``."""
        config = super().get_config()
        config.update({"groups": self.groups, "seed": self.seed})
        return config
# Smoke test: run the layer once on a dummy batch of 5 images, 224x224, with
# 3 channels (one channel per group). Every value is 1.0, so permuting the
# channel groups leaves the tensor unchanged, and the output is discarded:
# this only exercises the code path.
channel_shuffle = ChannelShuffle(groups=3)
images = tf.ones((5, 224, 224, 3))
channel_shuffle(images)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment