Created May 20, 2020 04:09
JAX implementation of github.com/srush/parallax using [pydantic](https://pydantic-docs.helpmanual.io/) instead of dataclasses
""" | |
Abstract Base class with automatic Pytree registration | |
Inspired by from https://github.com/google/jax/issues/2916 | |
""" | |
import json
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import jax.numpy as np
from jax import jit, vmap, random
from jax.interpreters.xla import DeviceArray
from jax.nn.initializers import normal
from jax.numpy import ndarray as Tensor
from jax.tree_util import register_pytree_node
from pydantic import BaseModel
class Module(BaseModel):
    """Abstract class for Modules. Enforces subclasses to implement
    __call__. Registered as a Pytree. Optional runtime validation of field types.
    Immutable. Simple JSON pretty print. OpenAPI schema generation.
    """

    class Config:
        allow_mutation = False
        arbitrary_types_allowed = True
        json_encoders = {DeviceArray: lambda t: f"shape={t.shape}"}

    def __init_subclass__(cls, is_abstract: bool = False, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)  # type: ignore
        # Concrete subclasses are registered as pytrees automatically;
        # intermediate abstract classes can opt out with is_abstract=True.
        if not is_abstract:
            register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)

    def __repr__(self) -> str:
        return json.dumps({self.__class__.__name__: json.loads(self.json())}, indent=4)

    def dict(self, **kwargs) -> Dict[str, Any]:  # type: ignore
        # Field-less modules (pure activations) still serialize to something readable.
        if len(self.__fields__) == 0:
            return {self.__class__.__name__: "activation"}
        return super().dict(**kwargs)

    @abstractmethod
    def __call__(self, inputs: Tensor, **kwargs: Any) -> Tensor:
        raise NotImplementedError
    def tree_flatten(self) -> Tuple[List[Any], List[str]]:
        # Leaves are the field values; the field names travel as static aux data
        # so that jit/grad only ever see array leaves.
        keys = list(self.__fields__.keys())
        return [getattr(self, key) for key in keys], keys

    @classmethod
    def tree_unflatten(cls, aux: List[str], params: List[Any]) -> "Module":
        # Disable validation when unflattening for speed.
        return cls.construct(**dict(zip(aux, params)))
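# --- A minimal sketch (not part of the original gist) of what the automatic
# registration buys us: any concrete Module is a pytree, so its instances can be
# flattened, jit-ed, and differentiated directly. `Scale` is a hypothetical
# example module used only for illustration.
from jax import grad
from jax.tree_util import tree_leaves


class Scale(Module):
    """Multiplies inputs elementwise by a learnable vector."""

    factor: Tensor

    @jit
    def __call__(self, inputs: Tensor, **kwargs: Any) -> Tensor:
        return self.factor * inputs


_scale = Scale(factor=np.ones(3))
assert len(tree_leaves(_scale)) == 1  # only the `factor` array is a leaf
_grads = grad(lambda module, x: np.sum(module(x)))(_scale, np.ones(3))
assert isinstance(_grads, Scale)  # gradients come back as another Scale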
class Linear(Module):
    """Dense Linear Layer.
    Computes output = np.dot(w, inputs) + b"""

    w: Tensor
    b: Tensor

    @jit
    def __call__(self, inputs: Tensor, **kwargs: Any) -> Tensor:
        """outputs = np.dot(w, inputs) + b in single-instance notation."""
        return np.dot(self.w, inputs) + self.b

    @classmethod
    def initialize(cls, *, input_size: int, output_size: int, key: Tensor) -> "Linear":
        """Factory for a new Linear from input and output dimensions."""
        return cls(
            w=normal(stddev=1.0)(key, shape=(output_size, input_size)),
            b=np.zeros(shape=(output_size,)),
        )
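# --- A minimal usage sketch (not part of the original gist): construct a Linear
# layer and apply it to a single instance. The shapes follow the single-instance
# convention used by __call__ above.
_key = random.PRNGKey(0)
_linear = Linear.initialize(input_size=3, output_size=2, key=_key)
assert _linear(np.ones(3)).shape == (2,)
print(_linear)  # pretty-printed JSON; arrays render via json_encoders as "shape=..."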
class Tanh(Module):
    @jit
    def __call__(self, inputs: Tensor, **kwargs: Any) -> Tensor:
        return np.tanh(inputs)
class MLP(Module):
    layers: List[Module]

    @jit
    def predict(self, single_input: Tensor, key: Optional[Tensor] = None) -> Tensor:
        """Predict for a single instance by iterating over all the layers."""
        for layer in self.layers:  # type: ignore
            single_input = layer(single_input, key=key)
        return single_input

    @jit
    def __call__(self, batched_inputs: Tensor, batched_keys: Optional[Tensor] = None) -> Tensor:
        """Batched predictions."""
        return vmap(self.predict)(batched_inputs, batched_keys)

    @classmethod
    def create_mlp(
        cls,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_hidden: int,
        key: Tensor,
    ) -> "MLP":
        key, subkey = random.split(key)
        layers: List[Module] = [
            Linear.initialize(input_size=input_dim, output_size=hidden_dim, key=subkey),
            Tanh(),
        ]
        for _ in range(num_hidden - 2):
            key, subkey = random.split(key)
            layers.append(
                Linear.initialize(input_size=hidden_dim, output_size=hidden_dim, key=subkey)
            )
            layers.append(Tanh())
        key, subkey = random.split(key)
        layers.append(
            Linear.initialize(input_size=hidden_dim, output_size=output_dim, key=subkey)
        )
        # Must use the unvalidated constructor, otherwise it throws an error.
        return cls.construct(layers=layers)
key = random.PRNGKey(42)
mlp = MLP.create_mlp(input_dim=10, hidden_dim=50, output_dim=4, num_hidden=5, key=key)
print(mlp)
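# --- A minimal training-step sketch (not part of the original gist). Because the
# MLP is itself a registered pytree, jax.grad differentiates with respect to the
# whole module, and tree_multimap applies the update leaf by leaf (newer JAX
# folds tree_multimap into tree_map). The loss, batch, and learning rate below
# are illustrative only.
from jax import grad
from jax.tree_util import tree_multimap


def mse_loss(model: MLP, inputs: Tensor, targets: Tensor) -> Tensor:
    predictions = model(inputs)
    return np.mean((predictions - targets) ** 2)


key, x_key, y_key = random.split(key, num=3)
x_batch = random.normal(x_key, shape=(8, 10))
y_batch = random.normal(y_key, shape=(8, 4))

grads = grad(mse_loss)(mlp, x_batch, y_batch)
updated_mlp = tree_multimap(lambda param, g: param - 0.01 * g, mlp, grads)
print(updated_mlp)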