Skip to content

Instantly share code, notes, and snippets.

@cgarciae
Created August 28, 2020 03:26
Show Gist options
  • Save cgarciae/63525c7b95ddde4fc0a147423314c939 to your computer and use it in GitHub Desktop.
Save cgarciae/63525c7b95ddde4fc0a147423314c939 to your computer and use it in GitHub Desktop.
Module
class MixtureModel(elegy.Module):
    """Mixture-model head: shared backbone, ``k`` two-output components, gating softmax.

    Args:
        k: Number of mixture components; sets both how many ``component``
            layers are stacked and the width of the ``gating`` layer.
        hidden_units: Width of the shared ``backbone`` layer. Defaults to 64,
            the value the original implementation hard-coded.
    """

    def __init__(self, k: int, hidden_units: int = 64):
        super().__init__()
        self.k = k
        self.hidden_units = hidden_units

    def call(self, x):
        """Compute component outputs and mixture weights for ``x``.

        Returns:
            A tuple ``(y, probs)``:

            * ``y`` holds the ``k`` component outputs stacked along axis 1.
              Each component outputs 2 values on the last axis. Channel 1 is
              passed through ``1 + elu(.)``, so it is strictly positive.
            * ``probs`` is a softmax over ``k`` gating logits on the last axis.

        NOTE(review): this assumes ``x`` is ``(batch, features)``, which gives
        ``y`` a shape of ``(batch, k, 2)`` and ``probs`` a shape of
        ``(batch, k)``. TODO confirm against callers.
        """
        x = elegy.nn.Linear(self.hidden_units, name="backbone")(x)
        x = jax.nn.relu(x)

        # One independent 2-output linear layer per mixture component.
        y: jnp.ndarray = jnp.stack(
            [elegy.nn.Linear(2, name="component")(x) for _ in range(self.k)],
            axis=1,
        )

        # jax.ops.index_update was removed from JAX (0.3.2); the functional
        # `.at[...].set(...)` update is the supported replacement and returns
        # a new array rather than mutating `y`.
        # Since elu(v) > -1, 1 + elu(v) > 0. Presumably channel 1 is a scale
        # parameter that must stay positive — confirm.
        y = y.at[..., 1].set(1.0 + jax.nn.elu(y[..., 1]))

        logits = elegy.nn.Linear(self.k, name="gating")(x)
        probs = jax.nn.softmax(logits, axis=-1)
        return y, probs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment