# referenced @chhillee's https://github.com/pytorch/functorch/blob/main/functorch/_src/nnc_compile.py
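# Sketch: JIT-compile a JAX-traceable function through PyTorch's NNC
# (TensorExpr) compiler: trace the function to a jaxpr, lower each equation
# to NNC exprs/stmts, build a LoopNest, codegen with LLVM, and run the
# compiled kernel on the (NumPy-converted) arguments.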
from typing import Callable, Dict, Any, List
from functools import partial

import numpy as np
import torch
import torch._C._te as te

from jax import core
from jax import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src.api_util import flatten_fun
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
                           safe_zip)

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
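
# pytorch_jit flattens the pytree arguments, wraps `f` so it takes and returns
# flat lists, and stages the whole call out through a call primitive.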
def pytorch_jit(f: Callable):
  def f_jit(*args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    flat_f, out_tree = flatten_fun(lu.wrap_init(f), in_tree)
    out_flat = pytorch_jit_p.bind(flat_f, *args_flat)
    return tree_unflatten(out_tree(), out_flat)
  return f_jit
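
# pytorch_jit_p is a call primitive; its impl rule traces the wrapped function
# to a jaxpr, compiles that jaxpr with NNC, and then executes the compiled
# kernel eagerly on the arguments.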
pytorch_jit_p = core.CallPrimitive('pytorch_jit')

@pytorch_jit_p.def_impl
def pytorch_jit_impl(f: lu.WrappedFun, *args):
  # trace
  in_avals = map(xla.abstractify, args)
  jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
  if consts: raise NotImplementedError

  # compile
  inputs = [parameter(i, a) for i, a in enumerate(in_avals)]
  outputs, out_stmts = nnc_eval_jaxpr(jaxpr, inputs)
  loopnest = te.LoopNest(te.Stmt(out_stmts), outputs)
  loopnest.simplify()
  loopnest.prepare_for_codegen()
  stmt = te.simplify(loopnest.root_stmt())
  cg = te.construct_codegen('llvm', stmt, [*inputs, *outputs])

  # execute
  ins = map(torch.from_numpy, map(np.array, args))
  outs = map(empty_like, out_avals)
  cg.call([*ins, *outs])
  return map(np.asarray, outs)
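
# Helpers mapping JAX avals (shape/dtype) to NNC buffers, shape exprs, dtypes,
# scalar literals, and output torch.Tensors.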
def parameter(idx: int, aval: core.ShapedArray) -> te.ExprHandle:
  return te.BufHandle(f'in_{idx}', shape_from_aval(aval), dtype_from_aval(aval))
def shape_from_aval(aval: core.ShapedArray) -> List[te.ExprHandle]:
  return map(te.ExprHandle.int, aval.shape)

def dtype_from_aval(aval: core.ShapedArray) -> te.Dtype:
  table = {'float32': te.Dtype.Float, 'int32': te.Dtype.Int,
           'bool': te.Dtype.Bool}
  return table[aval.dtype.name]
def empty_like(aval: core.ShapedArray) -> torch.Tensor:
  table = {'float32': torch.float32, 'int32': torch.int32,
           'bool': torch.bool}
  return torch.empty(aval.shape, dtype=table[aval.dtype.name])
def literal(aval: core.ShapedArray, val: Any) -> te.ExprHandle:
  if aval.dtype == np.dtype('float32'):
    return te.ExprHandle.float(val)
  elif aval.dtype == np.dtype('int32'):
    return te.ExprHandle.int(val)
  elif aval.dtype == np.dtype('bool'):
    return te.ExprHandle.bool(val)
  else:
    raise NotImplementedError(f'literal: {val}:{aval}')
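
# Evaluate a jaxpr symbolically into NNC IR: each equation is lowered via the
# `translations` table to its output exprs plus the stmts that compute them.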
def nnc_eval_jaxpr(jaxpr: core.Jaxpr, args):
  env: Dict[core.Var, te.ExprHandle] = {}
  stmts: List[te.Stmt] = []

  def read(x: core.Atom) -> te.ExprHandle:
    if type(x) is core.Literal:
      return literal(x.aval, x.val)
    else:
      return env[x]

  def write(v: core.Var, expr: te.ExprHandle) -> None:
    env[v] = expr

  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    in_avals = [x.aval for x in eqn.invars]
    out_avals = [v.aval for v in eqn.outvars]
    in_exprs = map(read, eqn.invars)
    rule = translations[eqn.primitive]
    out_exprs, out_stmts = rule(in_avals, out_avals, in_exprs, **eqn.params)
    stmts.extend(out_stmts)
    map(write, eqn.outvars, out_exprs)
  out_exprs = map(read, jaxpr.outvars)
  return out_exprs, stmts  # return all collected stmts, not just the last eqn's
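
# translations: JAX primitive -> rule(in_avals, out_avals, in_exprs, **params)
# returning ([output exprs], [stmts]).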
translations = {}

###
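
# standard_lowering delegates to NNC's registered aten lowerings via te.lower,
# so other ops with a matching `aten::<name>` lowering can presumably be
# registered the same way.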
from jax._src.lax import lax

def standard_lowering(name: str):
  name = f'aten::{name}'
  def lower(in_avals, out_avals, in_exprs):
    del in_avals
    aval, = out_avals
    out = te.lower(name, in_exprs, shape_from_aval(aval), dtype_from_aval(aval))
    return [out.buf()], [out.stmt()]
  return lower

translations[lax.sin_p] = standard_lowering('sin')
translations[lax.mul_p] = standard_lowering('mul')
translations[lax.cos_p] = standard_lowering('cos')  # are these names right?
###

from jax import grad
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])
y = pytorch_jit(jnp.sin)(x)
print(y)
print(jnp.sin(x))

x = jnp.array([1., 2., 3.])
y = pytorch_jit(lambda x: x * x)(x)
print(y)
print(x * x)
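
# Untested extra check (not in the original gist): a two-equation jaxpr, which
# exercises collecting stmts across equations in nnc_eval_jaxpr.
x = jnp.array([1., 2., 3.])
y = pytorch_jit(lambda x: jnp.sin(x) * x)(x)
print(y)
print(jnp.sin(x) * x)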