Synthetic data with Gumbel-Softmax activations
One-hot encoded representation (each category gets its own 0/1 column):

ID  Gender_Male  Gender_Female  AgeRange_10-19  AgeRange_20-29
1   1            0              0               1
2   0            1              1               0

Decoded back to the original categorical features:

ID  Gender  AgeRange
1   Male    20-29
2   Female  10-19
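For context, the round trip between the raw categorical columns and their one-hot encoding can be reproduced with pandas. This is a minimal sketch, not part of the original gist; the dataframe and column names simply mirror the tables above.

import pandas as pd

# Original categorical records, matching the decoded table above.
df = pd.DataFrame({'ID': [1, 2],
                   'Gender': ['Male', 'Female'],
                   'AgeRange': ['20-29', '10-19']})

# One-hot encode: each category becomes its own 0/1 column (first table above).
encoded = pd.get_dummies(df, columns=['Gender', 'AgeRange'])

# Decode: per feature, pick the column with the highest value (argmax) and strip the prefix.
decoded = pd.DataFrame({
    'ID': df['ID'],
    'Gender': encoded.filter(like='Gender_').idxmax(axis=1).str.replace('Gender_', '', regex=False),
    'AgeRange': encoded.filter(like='AgeRange_').idxmax(axis=1).str.replace('AgeRange_', '', regex=False),
})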
from typing import NamedTuple, Optional

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input


class Generator(tf.keras.Model):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
        # Simple MLP mapping a noise vector to the synthetic record space.
        input = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim, activation='relu')(input)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dense(dim * 4, activation='relu')(x)
        x = Dense(data_dim)(x)
        # Apply Gumbel-Softmax to the logits of the categorical features, if any.
        # GumbelSoftmaxActivation is defined elsewhere; it applies the layer below per feature.
        if activation_info:
            x = GumbelSoftmaxActivation(activation_info)(x)
        return Model(inputs=input, outputs=x)
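A minimal usage sketch, assuming a 128-dimensional Gaussian noise input and four output columns as in the tables above; the sizes are arbitrary, and activation_info is omitted so no Gumbel-Softmax is applied here.

import tensorflow as tf

noise_dim, data_dim, batch_size = 128, 4, 32  # assumed sizes, not from the gist

generator = Generator(batch_size=batch_size)
model = generator.build_model(input_shape=(noise_dim,), dim=64, data_dim=data_dim)

# Sample synthetic records from random noise (raw logits, since activation_info is None).
noise = tf.random.normal((batch_size, noise_dim))
fake_records = model(noise)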
from typing import Optional

from tensorflow import Tensor, TensorShape, one_hot, squeeze, stop_gradient
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.math import log
from tensorflow.nn import softmax
from tensorflow.random import categorical, uniform

TOL = 1e-20


def gumbel_noise(shape: TensorShape) -> Tensor:
    """Create a single sample from the standard (loc = 0, scale = 1) Gumbel distribution."""
    uniform_sample = uniform(shape, seed=0)
    return -log(-log(uniform_sample + TOL) + TOL)


@register_keras_serializable(package='Synthetic Data', name='GumbelSoftmaxLayer')
class GumbelSoftmaxLayer(Layer):
    """A Gumbel-Softmax layer meant to be stacked on top of the logits of a categorical feature."""

    def __init__(self, tau: float = 0.2, name: Optional[str] = None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.tau = tau

    def call(self, _input):
        """Computes the Gumbel-Softmax for the logits output of a particular categorical feature."""
        # Perturb the logits with Gumbel noise and anneal with the temperature tau.
        noised_input = _input + gumbel_noise(_input.shape)
        soft_sample = softmax(noised_input / self.tau, -1)
        # Draw a hard one-hot sample; stop_gradient keeps the discrete draw out of backpropagation.
        hard_sample = stop_gradient(squeeze(one_hot(categorical(log(soft_sample), 1), _input.shape[-1]), 1))
        return hard_sample, soft_sample

    def get_config(self):
        config = super().get_config().copy()
        config.update({'tau': self.tau})
        return config
Raw generator logits for the one-hot columns:

ID  Gender_Male  Gender_Female  AgeRange_10-19  AgeRange_20-29
1   0.867        0.622          -0.155          0.855
2   0.032        1.045          0.901           -0.122

Per-feature softmax of the same logits (each categorical block sums to 1):

ID  Gender_Male  Gender_Female  AgeRange_10-19  AgeRange_20-29
1   0.561        0.439          0.267           0.733
2   0.266        0.734          0.736           0.264
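To tie the tables to the layer above: a minimal sketch (the logits are copied from the first table; the shapes and per-feature slicing are assumptions) that recovers the per-feature softmax shown in the second table and then draws hard/soft Gumbel-Softmax samples. The layer's own outputs vary from run to run because of the random Gumbel noise and the tau=0.2 temperature.

import tensorflow as tf

# Raw generator logits for two records (columns ordered as in the tables above).
logits = tf.constant([[0.867, 0.622, -0.155, 0.855],
                      [0.032, 1.045, 0.901, -0.122]])

gender_logits, age_logits = logits[:, :2], logits[:, 2:]

# Plain per-feature softmax reproduces the second table (each block sums to 1).
print(tf.nn.softmax(gender_logits, axis=-1).numpy())  # ~[[0.561, 0.439], [0.266, 0.734]]
print(tf.nn.softmax(age_logits, axis=-1).numpy())     # ~[[0.267, 0.733], [0.736, 0.264]]

# Gumbel-Softmax sampling: a hard one-hot sample plus a differentiable soft sample per feature.
layer = GumbelSoftmaxLayer(tau=0.2)
gender_hard, gender_soft = layer(gender_logits)
age_hard, age_soft = layer(age_logits)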