Created June 13, 2022 11:24
-
-
Save innat/4e89725ccdcd763e0a6ba19216fd60bf to your computer and use it in GitHub Desktop.
layer jax2tf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class RandomGrayscale(layers.Layer):
    """Preprocessing layer that converts RGB images to grayscale.

    The conversion is a weighted sum over the last (channel) axis using the
    weights ``(0.2989, 0.5870, 0.1140)``. It is written in JAX and bridged
    into TensorFlow with ``jax2tf.convert`` when the layer is called with
    ``training=True``; with ``training=False`` the input is returned as-is.

    Args:
        output_channel: Number of channels in the output: ``1`` (a single
            grayscale channel) or ``3`` (the grayscale value repeated on three
            channels). Any other value raises ``ValueError``.
        prob: Stored and serialized by ``get_config`` but NOT currently used;
            the conversion is applied on every training call regardless of
            its value.
        **kwargs: Forwarded to ``layers.Layer``.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape
        `(..., height, width, 3)`.

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape
        `(..., height, width, output_channel)`.
    """

    def __init__(self, output_channel=1, prob=1, **kwargs):
        super().__init__(**kwargs)
        self.output_channel = self._check_input_params(output_channel)
        # NOTE(review): `prob` used to be discarded entirely. It is now kept so
        # the layer config round-trips, but no random gating is applied yet:
        # with output_channel=1 a skipped sample would have a different
        # channel count than a converted one, so gating needs a design choice.
        self.prob = prob

    def _check_input_params(self, output_channels):
        """Return ``output_channels`` unchanged if it is 1 or 3.

        Raises:
            ValueError: If ``output_channels`` is not 1 or 3.
        """
        if output_channels not in (1, 3):
            raise ValueError(
                "Received invalid argument output_channels. "
                f"output_channels must be in 1 or 3. Got {output_channels}"
            )
        return output_channels

    # `self` is a static jit argument, so the layer must be hashable and jit's
    # compilation cache holds a reference to it.
    @partial(jit, static_argnums=0)
    def _jax_gray_scale(self, images):
        """Convert ``images`` (last axis = 3 RGB channels) to grayscale.

        Floating-point inputs are converted in their own dtype. Non-float
        inputs are converted in float32, rounded, and cast back to the input
        dtype.

        Raises:
            ValueError: If the last axis does not have size 3, or if
                ``self.output_channel`` is neither 1 nor 3.
        """
        if images.shape[-1] != 3:
            raise ValueError(
                "RandomGrayscale expects RGB input with 3 channels on the "
                f"last axis. Got shape {images.shape}"
            )
        input_dtype = images.dtype
        is_float = jnp.issubdtype(input_dtype, jnp.floating)
        # Building the weights in an integer dtype would truncate them all to
        # 0, so non-float images are processed in float32 and cast back.
        compute_dtype = input_dtype if is_float else jnp.float32
        rgb_weights = jnp.array([0.2989, 0.5870, 0.1140], dtype=compute_dtype)
        grayscale = (rgb_weights * images.astype(compute_dtype)).sum(axis=-1)
        if not is_float:
            grayscale = jnp.round(grayscale).astype(input_dtype)

        if self.output_channel == 1:
            return jnp.expand_dims(grayscale, axis=-1)
        if self.output_channel == 3:
            return jnp.stack([grayscale] * 3, axis=-1)
        # Reachable only if `output_channel` was mutated after construction.
        raise ValueError("Unsupported value for `output_channels`.")

    def call(self, images, training=True):
        """Return grayscale ``images`` when ``training`` is truthy, else ``images``."""
        if not training:
            return images
        # One polymorphic-shape spec per traced argument (`self` is bound, so
        # only `images`); just the leading dimension is treated as symbolic.
        to_tf = jax2tf.convert(
            self._jax_gray_scale, polymorphic_shapes=["batch, ..."]
        )
        return to_tf(images)

    def get_config(self):
        """Return the base layer config extended with this layer's arguments."""
        config = {
            "output_channel": self.output_channel,
            "prob": self.prob,
        }
        return {**super().get_config(), **config}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.