
Just committed some updates to npcomp and getting excited... Full disclosure: it doesn't do anything yet, but the ideas are crystallizing for me.

It feels really good to get this idea out of my head. It's been bumping around for a while... A lot of the compilers I see for Python-ey things either have a really big hurdle you have to jump over to invoke compilation, or lack a way to get enough constraints in place to enable the optimizations that matter in a lot of cases (usually relying on fat annotations, fixed class hierarchies, etc). My idea is that we aren't really using the interactive power of Python to do the program extraction... There shouldn't be one big "compile" method (or one trace annotation like @tf.function, etc). There should be a conversation with the system, just like there is in normal Python programming. In that conversation, if I give it more, it should give me more.

So most of this stuff falls into the category of define-by-run or program-extraction (where we have a partially evaluated program in the interpreter and would like to extract a portion of it, potentially with tighter constraints than how it was originally defined). Why don't we just go more define-by-run?

Example:

>>> def simple_mul(a, b):
...  return a * b

# Begin extraction of 'simple_mul'
>>> exp = npcomp.exporter()
>>> exp.simple_mul = simple_mul

# We know nothing about it.
>>> exp.simple_mul
pyfunc simple_mul(Any, Any) -> Any

# Manually constrain it (example - lots of sugar can be added)
>>> exp.simple_mul.a = 'NdArray'
>>> exp.simple_mul.a += Rank(2) + DType(np.float32)
>>> exp.simple_mul
pyfunc simple_mul(NdArray[Rank(2), DType(np.float32)], Any) -> Any

# Or automatically constrain it.
>>> exp.simple_mul(np.zeros((3,3), dtype=np.float32), np.ones((3,3), dtype=np.float32))
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])
>>> exp.simple_mul
pyfunc simple_mul(NdArray[Shape(3, 3), DType(np.float32)], NdArray[Shape(3, 3), DType(np.float32)]) -> NdArray[Shape(3, 3), DType(np.float32)]

# Ok, now it is too constrained (static). Let's loosen it up.
# (There is magic here that will take some time to implement properly).
>>> exp.simple_mul.a += DynamicDim(0)
>>> exp.simple_mul
pyfunc simple_mul(NdArray[Shape(?, 3), DType(np.float32)], NdArray[Shape(3, 3), DType(np.float32)]) -> NdArray[Shape(?, 3), DType(np.float32)]

# Maybe do other things - it's a living artifact.
# More to this one, but if we fed it examples, it should have stats and
# additional constraints can be added in a similar way to how we constrained
# it in the first place. Significant magic still here.
>>> npcomp.quantize(exp.simple_mul)

# Good now. Compile it.
>>> compiled = npcomp.compile(exp)
>>> compiled
<Compiled Module of:
  func simple_mul(NdArray[Shape(?, 3), DType(np.float32)], NdArray[Shape(3, 3), DType(np.float32)]) -> NdArray[Shape(?, 3), DType(np.float32)]
>

# Run the compiled version.
>>> compiled.simple_mul([[1., 2., 3.]], [[4., 5., 6.]])
array([[ 4., 10., 18.]])

Right now, I'm starting to get the type system in place that allows this, and I have the bridge to MLIR set up. I'm going to start with doing simple tracing, but the approach should generalize to a full constrained-Python program extraction. As always, full fidelity is a lot of work, but I expect some meaningful and simple cases are not too hard.
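To make the "automatically constrain, then loosen" steps from the session above a bit more concrete, here is a rough sketch in plain Python. None of this is npcomp's actual API: `ExportedFunc`, `ArgConstraint`, `make_dynamic` and the `FakeArray` stand-in are all made-up names for illustration, and the return type is left out entirely. It only shows the bookkeeping: observe shapes and dtypes on a real call, then relax one dimension afterwards.

```python
import inspect
from collections import namedtuple
from dataclasses import dataclass
from typing import Optional, Tuple

# Stand-in for an array: anything with .shape and .dtype (e.g. a NumPy array) would do.
FakeArray = namedtuple("FakeArray", ["shape", "dtype"])


@dataclass(repr=False)
class ArgConstraint:
    """What is known so far about one argument (hypothetical structure)."""
    shape: Optional[Tuple[Optional[int], ...]] = None  # a None dim is dynamic
    dtype: Optional[str] = None

    def __repr__(self):
        if self.shape is None:
            return "Any"
        dims = ", ".join("?" if d is None else str(d) for d in self.shape)
        return f"NdArray[Shape({dims}), DType({self.dtype})]"


class ExportedFunc:
    """Wraps a Python function and accumulates constraints on its arguments."""

    def __init__(self, fn):
        self.fn = fn
        self.args = {name: ArgConstraint() for name in inspect.signature(fn).parameters}

    def __call__(self, *values):
        # Record shape/dtype of each positional argument, then run the function as usual.
        for name, value in zip(self.args, values):
            if hasattr(value, "shape") and hasattr(value, "dtype"):
                self.args[name] = ArgConstraint(tuple(value.shape), str(value.dtype))
        return self.fn(*values)

    def make_dynamic(self, arg, dim):
        # Loosen one dimension of a previously observed static shape.
        # (Assumes the argument has already been observed at least once.)
        shape = list(self.args[arg].shape)
        shape[dim] = None
        self.args[arg].shape = tuple(shape)

    def __repr__(self):
        params = ", ".join(repr(c) for c in self.args.values())
        return f"pyfunc {self.fn.__name__}({params})"


def passthrough(a, b):
    return a  # the body does not matter for this sketch


exp = ExportedFunc(passthrough)
print(exp)  # pyfunc passthrough(Any, Any)
exp(FakeArray((3, 3), "float32"), FakeArray((3, 3), "float32"))
exp.make_dynamic("a", 0)
print(exp)  # first argument now prints as NdArray[Shape(?, 3), DType(float32)]
```

The real thing would need the attribute-style sugar (`exp.simple_mul.a += ...`) and result-type inference on top, which is where most of the "magic" lives.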

Note also that I annotate things with "pyfunc". I expect there is another side of this where we define MLIR intrinsics that can be used as normal and captured into the compiled artifact for whole program optimization. A lot of that stuff is existing-adjacent, so it is a matter of bridging properly.

I also haven't gotten into captures, constants, etc. There is a similar axis of progressive optimization that can be done to those if we follow a similar approach (just capture by default but let them be interactively optimized).
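For what it's worth, Python already exposes enough to find closure captures, so "capture by default" could plausibly start from something like the sketch below. `list_captures` and `make_scaler` are hypothetical helpers, not part of npcomp; the only real machinery used is the standard `__closure__` and `__code__.co_freevars` attributes on function objects. Note this only sees closure variables, not globals.

```python
def list_captures(fn):
    """Map each free variable captured by fn's closure to its current value."""
    names = fn.__code__.co_freevars   # names of captured variables
    cells = fn.__closure__ or ()      # matching cells, or None if nothing is captured
    return {name: cell.cell_contents for name, cell in zip(names, cells)}


def make_scaler(scale):
    def apply(x):
        return x * scale  # 'scale' is captured from the enclosing scope
    return apply


scaler = make_scaler(2.0)
print(list_captures(scaler))  # {'scale': 2.0}
```

Each captured value found this way is a candidate for the same kind of interactive refinement as the arguments: leave it as a live capture, or freeze it as a constant.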

There are plenty of ways to generalize it too -- handling conflicting/overloaded constraint sets, capturing trial runs for guided optimization, etc. I also think that I can layer this such that even if I don't write the world's best tracer/compiler, we can plug outside implementations in and still use the general constraint and compiler framework (e.g. JAX traces, tf.function-esque, TorchScript, etc). So, in a sense, this is all the compiler frontend (with reference tracers/translators), but it is a different kind of compiler frontend than we've had before. In a lot of my model deployment work, I would have killed for something like this that let me interactively refine my way to good, high-performance, compiled artifacts.
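As one small illustration of using trial runs to guide constraints, here is a sketch of one possible merge policy: keep a dimension static while every observed run agrees on it, and mark it dynamic as soon as two runs disagree. `merge_shapes` is a hypothetical helper and this policy is just an assumption; a real system might prefer overloads or an explicit error in some cases.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]  # a None entry means "dynamic dimension"


def merge_shapes(seen: Optional[Shape], new: Sequence[int]) -> Shape:
    """Combine the shape known so far with a newly observed one."""
    if seen is None:
        return tuple(new)  # first observation: fully static
    if len(seen) != len(new):
        raise ValueError(f"rank mismatch: {len(seen)} vs {len(new)}")
    # Dimensions that agree stay static; any disagreement becomes dynamic.
    return tuple(s if s == n else None for s, n in zip(seen, new))


shape = None
for observed in [(3, 3), (5, 3), (7, 3)]:
    shape = merge_shapes(shape, observed)
print(shape)  # (None, 3)
```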
