Skip to content

Instantly share code, notes, and snippets.

@cgarciae
Last active July 8, 2020 01:04
Show Gist options
  • Save cgarciae/e003bb36a5af0c4ef315c9fb2da85e11 to your computer and use it in GitHub Desktop.
Save cgarciae/e003bb36a5af0c4ef315c9fb2da85e11 to your computer and use it in GitHub Desktop.
Simple cumulative accuracy metric using Haiku and JAX.
class Accuracy(hk.Module):
    """Cumulative classification accuracy tracked in Haiku state.

    Every call adds the batch's number of correct predictions and number of
    labels to the ``"total"`` and ``"count"`` state entries and returns their
    ratio, i.e. the accuracy over all batches seen so far.

    The module reads and writes state through ``hk.get_state`` /
    ``hk.set_state``, so it must be used inside a function wrapped with
    ``hk.transform_with_state``.
    """

    # The update lives in ``__call__`` rather than ``__init__``: a constructor
    # cannot hand a value back to the caller, so the accuracy could never be
    # read if this logic ran at construction time.
    def __call__(self, y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
        """Fold one batch into the running totals and return the accuracy.

        Args:
            y_true: Ground-truth class indices.
            y_pred: Per-class scores; the predicted class is the argmax over
                the last axis. Presumably shaped ``y_true.shape +
                (num_classes,)`` -- confirm against the caller.

        Returns:
            Scalar array: correct predictions so far divided by the number of
            labels so far.
        """
        total = hk.get_state(
            "total", shape=[], dtype=jnp.float32, init=hk.initializers.Constant(0)
        )
        # NOTE: unless JAX 64-bit mode is enabled, an int64 request is
        # truncated to int32.
        count = hk.get_state(
            "count", shape=[], dtype=jnp.int64, init=hk.initializers.Constant(0)
        )

        hits = y_true == jnp.argmax(y_pred, axis=-1)

        # Accumulate the *number* of correct predictions, not the batch mean:
        # ``count`` grows by the number of labels, so only a sum keeps
        # ``total / count`` a true running accuracy (also for uneven batches).
        total = total + jnp.sum(hits).astype(total.dtype)
        # ``size`` is the static element count; no array reduction needed.
        count = count + y_true.size

        hk.set_state("total", total)
        hk.set_state("count", count)
        return total / count
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment