Skip to content

Instantly share code, notes, and snippets.

@niloch
Created May 20, 2020 04:09
Show Gist options
  • Save niloch/d414e9e647fa24f7bee2be9ce2eeceeb to your computer and use it in GitHub Desktop.
Save niloch/d414e9e647fa24f7bee2be9ce2eeceeb to your computer and use it in GitHub Desktop.
JAX implementation of [parallax](https://github.com/srush/parallax) using [pydantic](https://pydantic-docs.helpmanual.io/) models instead of dataclasses.
"""
Abstract Base class with automatic Pytree registration
Inspired by from https://github.com/google/jax/issues/2916
"""
import json
from abc import abstractmethod
from typing import Any, Dict, List, Tuple
import jax.numpy as np
from jax import jit, vmap, random
from jax.interpreters.xla import DeviceArray
from jax.nn.initializers import normal
from jax.numpy import ndarray as Tensor
from jax.tree_util import register_pytree_node
from pydantic import BaseModel
class Module(BaseModel):
    """Abstract base class for modules.

    - Enforces that subclasses implement ``__call__`` (``@abstractmethod``).
    - Registers every concrete subclass as a JAX pytree, so instances can be
      passed through transformations such as ``jit`` and ``vmap``.
    - Optional runtime validation of field types (the normal constructor
      validates; ``construct`` skips validation).
    - Immutable (``allow_mutation = False``).
    - Simple pretty-printed JSON ``repr``/``str``.
    - OpenAPI schema generation (inherited from pydantic).
    """

    class Config:
        allow_mutation = False
        arbitrary_types_allowed = True
        # Arrays are summarized by their shape instead of dumping their contents.
        json_encoders = {DeviceArray: lambda t: f"shape={t.shape}"}

    def __init_subclass__(cls, is_abstract: bool = False, **kwargs: Any) -> None:
        """Register each subclass as a pytree unless ``is_abstract=True``."""
        super().__init_subclass__(**kwargs)  # type: ignore
        if not is_abstract:
            register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)

    def __repr__(self) -> str:
        """Pretty-printed JSON of the form ``{ClassName: {field: value, ...}}``."""
        return json.dumps({self.__class__.__name__: json.loads(self.json())}, indent=4)

    def __str__(self) -> str:
        # pydantic's BaseModel defines its own ``__str__``; route it through the
        # JSON ``__repr__`` so that ``print(module)`` shows the pretty form too.
        return self.__repr__()

    def dict(self, **kwargs: Any) -> Dict[str, Any]:  # type: ignore
        """Serialize to a dict.

        Modules without fields render as ``{ClassName: "activation"}`` so they
        still appear (with their name) in the JSON representation.
        """
        if not self.__fields__:
            return {self.__class__.__name__: "activation"}
        return super().dict(**kwargs)

    @abstractmethod
    def __call__(self, inputs: Tensor, **kwargs: Any) -> Tensor:
        raise NotImplementedError

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
        """Split the module into pytree children and auxiliary data.

        Returns:
            ``(values, names)``: the field values (the pytree children) and
            the matching field names (static aux data).

        Values are read with ``getattr`` instead of ``self.dict()``.
        ``dict()`` recursively turns nested modules into plain dicts, and
        flattening ``(name, value)`` pairs would make the name strings pytree
        leaves, which ``jit`` rejects as non-array arguments.
        """
        names = tuple(self.__fields__)
        values = tuple(getattr(self, name) for name in names)
        return values, names

    @classmethod
    def tree_unflatten(cls, aux: Any, params: List[Any]) -> "Module":
        """Rebuild a module from ``tree_flatten`` output.

        Args:
            aux: Field names, in the same order as ``params``. ``None`` selects
                the legacy format in which ``params`` holds ``(name, value)``
                pairs.
            params: Field values (or legacy ``(name, value)`` pairs).
        """
        if aux is None:
            mapping = {name: value for name, value in params}
        else:
            mapping = {name: value for name, value in zip(aux, params)}
        # Disable validation when unflattening, for speed; during tracing the
        # leaves may also be tracers rather than concrete arrays.
        return cls.construct(**mapping)
class Linear(Module):
    """Dense linear layer.

    Computes ``output = np.dot(w, inputs) + b``. When built via
    ``initialize``, ``w`` has shape ``(output_size, input_size)`` and ``b``
    has shape ``(output_size,)``.
    """

    w: Tensor
    b: Tensor

    @jit
    def __call__(self, inputs: Tensor, **kwargs: Any) -> Tensor:
        """Return ``np.dot(w, inputs) + b`` in single-instance notation.

        Extra keyword arguments (for example the ``key`` that ``MLP.predict``
        passes to every layer) are accepted and ignored.
        """
        return np.dot(self.w, inputs) + self.b

    @classmethod
    def initialize(
        cls,
        *,
        input_size: int,
        output_size: int,
        key: Tensor,
        stddev: float = 1.0,
    ) -> "Linear":
        """Create a new ``Linear`` layer from input and output dimensions.

        Args:
            input_size: Size of the input vector.
            output_size: Size of the output vector.
            key: PRNG key used to sample the weights.
            stddev: Standard deviation of the normal weight initializer
                (defaults to 1.0, the previously hard-coded value).

        The bias is initialized to zeros.
        """
        return cls(
            w=normal(stddev=stddev)(key, shape=(output_size, input_size)),
            b=np.zeros(shape=(output_size,)),
        )
class Tanh(Module):
    """Parameter-free element-wise ``tanh`` activation."""

    @jit
    def __call__(self, inputs: Tensor, **kwargs: Any) -> Tensor:
        """Apply ``np.tanh`` element-wise; extra keyword arguments are ignored."""
        activated = np.tanh(inputs)
        return activated
class MLP(Module):
    """Multi-layer perceptron: a sequence of layers applied in order."""

    # Declared with an annotation: the previous ``layers = List[Module]``
    # assigned the typing alias itself as a value instead of declaring the
    # field's type.
    layers: List[Module]

    @jit
    def predict(self, single_input: Tensor, key: Tensor = None) -> Tensor:
        """Predict for a single instance by applying every layer in order.

        ``key`` is forwarded unchanged to each layer as ``key=...``.
        """
        for layer in self.layers:
            single_input = layer(single_input, key=key)
        return single_input

    @jit
    def __call__(self, batched_inputs: Tensor, batched_keys: Tensor = None) -> Tensor:
        """Batched predictions: ``predict`` vmapped over the leading axis of
        ``batched_inputs`` (and of ``batched_keys`` when it is given)."""
        return vmap(self.predict)(batched_inputs, batched_keys)

    @classmethod
    def create_mlp(
        cls,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_hidden: int,
        key: Tensor,
    ) -> "MLP":
        """Build ``Linear`` layers separated by ``Tanh`` activations.

        Layer sizes go ``input_dim -> hidden_dim -> ... -> hidden_dim ->
        output_dim``. ``num_hidden`` is the total number of ``Linear`` layers.
        Values below 2 still produce two layers (input->hidden and
        hidden->output). Each ``Linear`` gets its own subkey, split from
        ``key`` in layer order.
        """
        num_linear = max(num_hidden, 2)
        dims = [input_dim] + [hidden_dim] * (num_linear - 1) + [output_dim]
        layers: List[Module] = []
        for index, (in_size, out_size) in enumerate(zip(dims[:-1], dims[1:])):
            if index > 0:
                # Activation between consecutive Linear layers, none after the last.
                layers.append(Tanh())
            key, subkey = random.split(key)
            layers.append(
                Linear.initialize(input_size=in_size, output_size=out_size, key=subkey)
            )
        # Unvalidated constructor: every layer was just built through a
        # validated constructor, so validating the list again is unnecessary.
        return cls.construct(layers=layers)
# Demo: build an MLP from a fixed seed and print it.
key = random.PRNGKey(42)
mlp = MLP.create_mlp(
    key=key,
    input_dim=10,
    hidden_dim=50,
    output_dim=4,
    num_hidden=5,
)
print(mlp)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment