Skip to content

Instantly share code, notes, and snippets.

@innat
Created June 13, 2022 11:24
Show Gist options
  • Save innat/4e89725ccdcd763e0a6ba19216fd60bf to your computer and use it in GitHub Desktop.
Save innat/4e89725ccdcd763e0a6ba19216fd60bf to your computer and use it in GitHub Desktop.
layer jax2tf
class RandomGrayscale(layers.Layer):
    """Preprocessing layer that converts RGB images to grayscale.

    The conversion is written in JAX and lowered to TensorFlow with
    ``jax2tf.convert`` so it can run inside a Keras model. The leading
    dimension is treated as shape-polymorphic, so one lowering serves any
    batch size.

    Args:
        output_channel: Number of output channels: ``1`` (a single
            luminance channel) or ``3`` (luminance repeated on three
            channels). Any other value raises ``ValueError``.
        prob: Stored and serialized by ``get_config`` but NOT yet used;
            the conversion is always applied when ``training`` is truthy.
            NOTE(review): applying the transform with probability ``prob``
            is still to be implemented.
        **kwargs: Forwarded to the base ``layers.Layer`` constructor.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape
        `(..., height, width, channels)`; the last axis is expected to be
        RGB (size 3).

    Output shape:
        Same as the input, except the last axis has size ``output_channel``.
    """

    def __init__(self, output_channel=1, prob=1, **kwargs):
        super().__init__(**kwargs)
        self.output_channel = self._check_input_params(output_channel)
        # Kept so the argument round-trips through get_config/from_config;
        # it does not influence `call` yet.
        self.prob = prob

    def _check_input_params(self, output_channels):
        """Validate ``output_channels`` and return it unchanged.

        Raises:
            ValueError: if ``output_channels`` is not 1 or 3.
        """
        if output_channels not in [1, 3]:
            raise ValueError(
                "Received invalid argument output_channels. "
                f"output_channels must be in 1 or 3. Got {output_channels}"
            )
        return output_channels

    @staticmethod
    @partial(jit, static_argnums=1)
    def _grayscale_kernel(images, output_channel):
        """JAX kernel: weighted RGB -> luminance, then shape the channel axis.

        ``output_channel`` is a static (hashable int) argument, so jit caches
        one trace per channel count instead of keying its cache on the layer
        instance (which would keep every layer alive and ignore later changes
        to ``output_channel``).
        """
        is_float = jnp.issubdtype(images.dtype, jnp.floating)
        # Casting the fractional weights to an integer dtype would truncate
        # them all to 0 (an all-black image), so integer inputs are weighted
        # in float32 and rounded back to their original dtype afterwards.
        compute_dtype = images.dtype if is_float else jnp.float32
        rgb_weights = jnp.array([0.2989, 0.5870, 0.1140], dtype=compute_dtype)
        grayscale = (rgb_weights * images.astype(compute_dtype)).sum(axis=-1)
        if not is_float:
            grayscale = jnp.round(grayscale).astype(images.dtype)

        if output_channel == 1:
            return jnp.expand_dims(grayscale, axis=-1)
        if output_channel == 3:
            return jnp.stack([grayscale] * 3, axis=-1)
        raise ValueError("Unsupported value for `output_channels`.")

    def _jax_gray_scale(self, images):
        """Convert ``images`` to grayscale with ``self.output_channel`` channels."""
        return self._grayscale_kernel(images, self.output_channel)

    def call(self, images, training=True):
        """Apply the grayscale conversion when ``training`` is truthy.

        Otherwise ``images`` is returned unchanged.
        """
        if not training:
            return images
        # "batch, ..." makes the leading dimension symbolic so the lowered
        # function accepts any batch size; "..." keeps the other dimensions
        # as given by the input.
        convert_fn = jax2tf.convert(
            self._jax_gray_scale, polymorphic_shapes=("batch, ...")
        )
        return convert_fn(images)

    def get_config(self):
        """Return the layer config, including its constructor arguments."""
        config = {
            "output_channel": self.output_channel,
            "prob": self.prob,
        }
        base_config = super().get_config()
        return {**base_config, **config}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment