Skip to content

Instantly share code, notes, and snippets.

@jayendra13
Created January 4, 2022 17:30
Show Gist options
  • Save jayendra13/af1d177629b5d80188df298f25377022 to your computer and use it in GitHub Desktop.
Save jayendra13/af1d177629b5d80188df298f25377022 to your computer and use it in GitHub Desktop.
import gin
import jax
import jax.numpy as jnp
import haiku as hk
@gin.configurable
class Model(hk.Module):
    """Haiku module that feeds its input through a single ``hk.Linear`` layer.

    The class is registered with gin, so ``num_classes`` can be supplied by a
    gin binding instead of being passed explicitly at the call site.
    """

    def __init__(self, num_classes, name=None):
        """Store the layer size.

        Args:
            num_classes: Value handed to ``hk.Linear`` when the module is
                called.
            name: Optional module name forwarded to ``hk.Module``.
        """
        super().__init__(name=name)
        self.num_classes = num_classes

    def __call__(self, x):
        """Apply a freshly constructed ``hk.Linear(self.num_classes)`` to ``x``."""
        projection = hk.Linear(self.num_classes)
        return projection(x)
"""Content of the gin file
Model.num_classes = 12
"""
# Load the gin bindings before any Model is constructed, since Model() below
# is called without arguments and relies on gin to provide num_classes.
gin.parse_config_file("config.gin")


def _forward(*args):
    """Construct a gin-configured Model and apply it to ``args``."""
    return Model()(*args)


model = hk.without_apply_rng(hk.transform(_forward))
# NOTE(review): `model` here is the transformed wrapper (used via `.init`
# below), presumably not a Model instance -- hence the disabled line.
# print(model.num_classes)
print(dir(model))

# Dummy batch used only to initialise the parameters.
x = jnp.ones((4, 5))
rng = jax.random.PRNGKey(42)
params = model.init(rng, x)
print(params)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment