@benvanik
Last active November 20, 2020 08:05
IREE WebAssembly executable kernel manifesto

This is the shape of what we are talking about:

// Let's run a tile worth of work within a larger grid dispatch. That grid is
// defined *by us in the compiler* - it could, for example, be a grid of 1x1x1
// such that this function is called once. Or, if there was benefit, you could
// make it go wider (like how ruy fans out work to a threadpool). And you can
// emit the code to choose the grid size/shape at runtime based on anything you
// want. That's what IREE gives you today. This here is your executable kernel,
// the equivalent of a CUDA kernel or compute shader.
//
// Input/output buffers are provided and may be dynamically shaped. If they are
// you can get their shape at runtime. They could be partially shaped too (like
// memref<?x128x64xi8>), or have attached whole-program analysis or profile
// guided trace-derived known values ("this is always called with the outer
// dimension values of 1, 8, or ?") that would let you specialize however you
// wanted (or, not specialize at all if it didn't matter).
//
// %some_dynamic_value is anything you want passed in - it could be a value from
// the user all the way out at the top level (a flag indicating whether to use
// algorithm A or B) or a value from the runtime like the L1 cache size or
// microarchitecture of the current core this tile is executing on.
//
// The purpose of this tile is to use zero or more input buffers to fill one or
// more output buffers. That's it.
func @dispatch_tile(%input_buffer : memref<?x?xf32>,
                    %output_buffer : memref<?x?xf32>,
                    %some_dynamic_value : i32) {
  // Get the dynamic dimension values assuming both were totally unknown at
  // compile time. If we knew a dimension here this would fold away and allow
  // more compile-time optimization (loop unrolling decisions, loop
  // interleaving to reduce loop overhead, etc).
  %inner_loop_end = dim %output_buffer, 1 : memref<?x?xf32>
  %outer_loop_end = dim %output_buffer, 0 : memref<?x?xf32>
  affine.for %i0 = 0 to %outer_loop_end {
    affine.for %i1 = 0 to %inner_loop_end {
      // Load a 128x16 tile from the input buffer into a vector size of our
      // inner tile dims. This could be any tile size we wanted, and if we
      // wanted to specialize for multiple of these we could do it by literally
      // putting an if statement (scf.if, or std.cond_br) around the outer loop
      // that switches at runtime.
      //
      // Note that the permutation map lets you change how the tiles are
      // indexed and here it's transposing (d0->d1 and d1->d0) but you can have
      // much more complex symbolic expressions (and if you properly use the
      // vector dialect most of those are done for you as part of work
      // fusion/distribution/tiling). The point is that you can express them.
      %input_tile = vector.transfer_read %input_buffer[%i0, %i1]
           {permutation_map = (d0, d1) -> (d1, d0)} :
           memref<?x?xf32>, vector<128x16xf32>

      // Here is where your actual ops happen. They could be indexing ops
      // (like vector.transpose, vector.scatter, etc), math ops (like
      // vector.fma), *or arbitrary std ops like std.exp and std.div*. They
      // could even be calls to library functions that are made at runtime
      // that just get passed pointers to the IO.
      // The more that we know about the access patterns, though, the more we
      // can fold away/distribute work and efficiently tile things, but that's
      // strictly an optimization: this would all function if both the
      // %inner_loop_end and %outer_loop_end were 1 (loops would go away) and
      // the %input_tile and %output_tile were the entire contents of the IO
      // buffers. Of course, that's pretty silly as you are just leaving
      // performance on the table.
      %i1_i32 = index_cast %i1 : index to i32
      %some_derived_value = divi_signed %i1_i32, %some_dynamic_value : i32
      %output_tile = ruy.do_whatever_op %input_tile, %some_derived_value { arbitrary arguments and information }

      // Write the results back to the output buffer. Note that the
      // permutation map allows you to write them back in any arrangement you
      // want.
      vector.transfer_write %output_tile, %output_buffer[%i0, %i1]
           {permutation_map = (d0, d1) -> (d0, d1)} :
           vector<128x16xf32>, memref<?x?xf32>
    }

    // There's no reason you can't also perform ops here - it's just code. You
    // can have imperfectly nested loops, early-exit from loops, etc - **it's
    // just code**.
  }
  // Can also have other loops here, or whatever else - that's how you can implement a damn
  // convincing Minecraft in a handful of kernels: https://www.shadertoy.com/view/wsByWV.
  return
}

When using this with a black-box opaque ruy.do_whatever_op you still get all the loop traversal/indexing/etc benefits and the ability to have multiple of those loops (and code motion across those loops), as well as the ability to insert any arbitrary ops you want. Or, you could restrict the whole set and say the only ops available at runtime are integer arithmetic for loop/flow control and that there's no way to do any floating point/etc without going into a black-box op. You lower ahead-of-time from this higher-level representation to your lower-level deployment ISA, and at runtime you are JIT'ing some stupid simple code that is on the order of LuaJIT complexity. If you were only supporting these opaque ops for doing the actual mathy parts you wouldn't even be touching anything SIMD - it's just some integer arithmetic and branches. Like, 25 instructions, most of which are unary and binary arithmetic. This is what the IREE VM is doing, for example.
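
To make that concrete, here's a rough C analogue (just a sketch - the names, signature, and flat row-major indexing are made up for illustration) of what's left for the runtime to execute once all the mathy work lives behind a blackbox call: integer loop control, a division, and an indirect call.

// Sketch only: what dispatch_tile boils down to at runtime when the only
// non-blackbox ops are integer arithmetic and branches. The function-pointer
// type and the flat row-major indexing are assumptions for illustration.
#include <stdint.h>

typedef void (*blackbox_fn)(const float *input_tile, float *output_tile,
                            int32_t derived_value);

void dispatch_tile(const float *input, float *output,
                   int32_t outer_end, int32_t inner_end,
                   int32_t some_dynamic_value, blackbox_fn do_whatever) {
  for (int32_t i0 = 0; i0 < outer_end; ++i0) {
    for (int32_t i1 = 0; i1 < inner_end; ++i1) {
      // Plain integer arithmetic stays in the JIT'd/interpreted code
      // (%some_dynamic_value assumed non-zero here)...
      int32_t derived = i1 / some_dynamic_value;
      // ...and everything SIMD-ish happens inside the hand-written blackbox.
      do_whatever(&input[(int64_t)i0 * inner_end + i1],
                  &output[(int64_t)i0 * inner_end + i1], derived);
    }
  }
}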

Conceptually you could go extremely lightweight here by turning the entire dispatch_tile function into a single mondo fused op that took all the loop dimensions, dynamic runtime values, buffers, etc and did a call out to a single blackbox op. That is what this would become if the workgroup size were 1x1x1 and there was no inner tiling, just a single call. IREE had this support for a while - you could denote that an executable was a special handwritten one that had a well-defined API. It's not a useful thing to do in 2020, though - we know it's possible, we know the limitations, and we also know it does not save much time relative to doing it correctly. We aren't looking to build yet another way to implement tflite's 40+ argument LSTM op :)

As for how you may execute something like this, there are many off-the-shelf solutions that - given some interchange format - let you interpret or JIT (or a mix of both). Cranelift, for example, is lightweight but full-featured and a good example of something that could be used for this; incidentally, it's developed to JIT webassembly. There are also things like wamr, which is AOT/JIT/interpreter and runs on just about every arch (mips/aarch64/thumb/x86/etc) (this is my current favorite). Whether the expressions you want to store are textual MLIR, some hypothetical MLIR-adjacent binary representation, your own custom thing (like IREE's VM bytecode), or webassembly bytecode (which is Cranelift's serialization format) doesn't matter. You are expressing a super simple fragment of a program. People have been doing this for 50 years. This is not new and there are no surprises, it's just engineering. It's possible to very, very, very poorly reimplement this all by using a protobuf or whatever with some fixed behavior - like, say, most ML runtimes that exist today. Why? And this is me - who went against all this advice and wrote the IREE VM - saying that (and I want to unwind that :)

So now you have a way to express these things in an eminently optimizable form ahead-of-time, can express both arbitrarily general logic and black-box do-whatever-you-want ops of your choosing, can deploy that to runtime, and then at runtime turn that into native ISA instructions. You may then decide to either treat your blackbox ops as function calls you give to the JIT to insert as jump-and-link instructions (which is not hard), or plug into the JIT to insert your own expressions (also not hard) such that the JIT can save some registers for you (for loop stuff), or somewhere in between (emit your runtime-decided flow control expressions to the JIT but then in the loop/conditional body call in to your hand-written stuff).

The very strong and well-founded assertion here is that no step in this architecture is something you can skip and have anything but a throwaway toy. You need to start with an input program in MLIR (which may be linalg or some mix of vector ops and standard ops), so you're going to author a dialect and some conversions. You need to lower that as far as you can to some deployable format that you can decouple from your runtime (unless you want to be like TF, where graphdefs are forward compatible for often only a 3 week window in time), and it makes sense to do things here that don't need to be done at runtime, like constant folding and CSE (why waste extra bytes on the wire/disk and extra time during JIT/execution when you could just not do that?). And then finally you need to load that deployment artifact and execute it at runtime. What doesn't matter to this architecture is the granularity of the operations you are hand-writing as blackboxes (are they loop bodies? fragments of loop bodies? entire functions containing multiple loop bodies? etc), or your interchange format: seriously, you could lower it to brainfuck if you wanted and parse that at runtime - it's Turing complete! You need all of these components to have something that works. There's no shortcut that lets you skip these layers.

The trick to engineering this for longevity and further extension is to choose the components so that you don't paint yourself into a corner. Do you need to invent your own interchange format? No, you don't. Use webassembly - it's fine, and then the code you own (and test and secure and such) at runtime is literally 20 lines of API calls to register your custom blackbox op functions and get function pointers: https://github.com/bytecodealliance/wasmtime/blob/main/cranelift/simplejit/examples/simplejit-minimal.rs#L48-L82 https://github.com/bytecodealliance/wasm-micro-runtime/blob/main/doc/embed_wamr.md#native-calls-wasm-functions-and-passes-parameters Let the browser vendors and hosted computing services deploying WASI do that hard work. 10 years from now some new shiny interchange format comes along? Just have a little WASM->that shim in your new runtimes and you can still load your 10 year old compiled artifacts - this is a solved engineering problem outside of ML :)
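
To make the "20 lines of API calls" concrete, here's roughly what that glue could look like against wamr's C embedding API (circa 2020; details shift between versions). The module name, function names, signature string, and memory sizes are placeholders for whatever your compiler actually emits, and error handling is omitted:

// Sketch of the runtime glue, assuming WAMR's C embedding API.
// "env", "ruy_do_whatever", "dispatch_tile", and "(iii)i" are placeholders.
#include <stdbool.h>
#include <stdint.h>
#include "wasm_export.h"

// Native implementation of a blackbox op the wasm module imports.
static int32_t ruy_do_whatever(wasm_exec_env_t exec_env, int32_t in_offset,
                               int32_t out_offset, int32_t dynamic_value) {
  // Call into your hand-written kernel here, using the instance memory.
  return 0;
}

static NativeSymbol kNativeSymbols[] = {
    {"ruy_do_whatever", ruy_do_whatever, "(iii)i", NULL},
};

bool run_dispatch(uint8_t *wasm_bytes, uint32_t wasm_size) {
  char err[128];
  wasm_runtime_init();
  wasm_runtime_register_natives("env", kNativeSymbols,
                                sizeof(kNativeSymbols) / sizeof(NativeSymbol));
  wasm_module_t module =
      wasm_runtime_load(wasm_bytes, wasm_size, err, sizeof(err));
  wasm_module_inst_t inst =
      wasm_runtime_instantiate(module, 16 * 1024, 64 * 1024, err, sizeof(err));
  wasm_exec_env_t env = wasm_runtime_create_exec_env(inst, 16 * 1024);
  wasm_function_inst_t entry =
      wasm_runtime_lookup_function(inst, "dispatch_tile", NULL);
  uint32_t argv[3] = {0, 0, 0};  // buffer offsets + %some_dynamic_value
  return wasm_runtime_call_wasm(env, entry, 3, argv);
}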

If you are targeting something wamr/cranelift/etc supports (x86/64 and aarch64 today) you're done. Have a beer. But the important thing is that now you have a choice: running on a microcontroller? Use an interpreter like wasm3 - you aren't running anything in it but some outer loop iterations and hoisted flow control ("are there < 1024 elements or > 1024 elements? choose the strategy based on that") and it doesn't matter if it supports SIMD: whether you are JITing or interpreting, you as a blackbox op author are writing C functions to put in a function table for it to call. Worried about loop overhead? Tune the compiler to pick larger inner tile sizes such that you are literally executing dozens of loop instructions per iteration. This is a decision you'd have to make regardless of how you decided to serialize your intermediate representation to disk or how you decided to convert that into runnable code - even if doing it via monkey-patching function pointers!

So what I'm saying here is that if you are going to build something that lets you go from something in the compiler to something at runtime, nothing about how you accomplish that actually matters to what you care about: the granularity and contents of your custom blackbox ops and their runtime implementations. So there is no reasonable option but to choose a path that allows you to skip as much of this tax as you can and never have to touch it again. You use a standard MLIR lowering into webassembly (which today goes through LLVM-IR, but doesn't have to) that turns the lowered branch, arithmetic, and function call ops into webassembly expressions, serialize that to the webassembly bytecode format, and pick your choice of how you execute that. Write a little bit of C glue to register your functions and then get on with the interesting problems.

And again, you can't remove any of these steps, and this kind of solution is the least work and has tons of benefits over rolling your own everything that lead to faster production deployment: security reviews/fuzzing on shared infra, many other people working on optimization and platform support, tools for debugging/diagnostics, etc. It's a side-benefit (that matters to me) that once in place it won't need to change - yes, you may decide to adjust granularities of your blackbox ops, or add some MLIR conversion patterns that take your blackbox ops and decompose them into smaller compositions of blackbox ops plus vector/standard dialect ops in a way that allows for better optimization, or improve your handwritten runtime code, or integrate a new wasm runtime for a particular deployment scenario, but that is all work you would have to do anyway.

Finally, the engineering longevity comes in here not just from using a standard interchange with a ratified spec that will likely be supported as an input to some modernization pipeline for as long as humanity exists, but also from the fact that it is a developing spec that will get more things over time (SIMD, additional SIMD instructions, etc). Missing something today? Blackbox call, just as you would do anyway. But as those features start to land you can come and add tiny targeted MLIR conversion passes that turn your once-blackbox op into other supported ops and start to remove the need for the blackbox op to exist at runtime at all. It's like how Windows still ships a thousand DLLs in C:\Windows\System32 even if a lot of them are just shims that call out into other code or are there purely for legacy support - as someone deploying their app starts to run programs (née models) that need fewer blackbox ops (or a different mix of new ones) they can continue to either drop dependencies/binary size or spend those bytes where it matters. The best solution is to scale up, though: always have an -O0 mode that lowers the vector dialect to scalar loops and will run (slowly) on a potato with a 1-2kLoC interpreter. Now you have 100% reach to any device that can run C code, and for models like screen brightness control or health that are dealing with dozens of integer activations that's all you need. By only adding the complexity and tax associated with runtime blackbox ops when you need it (for performance/power/etc for a particular use case by a particular user for a particular model) you are setting yourself up to do the least amount of work with the largest reach and being able to spend your time solving more interesting problems, like what those blackbox ops are and how fast you can make them for what you care about.

benvanik commented Nov 19, 2020

As a follow-up, it's important to see the stages between source and execution as points on a spectrum that have some wiggle-room. The approach above has the lowering from {vector dialect with arbitrary vector widths} -> {fixed-width vectors and blackbox ops} happen on the compiler side pre-deployment. There are other places to pin that lowering, if it makes sense, which shift smaller amounts of compiler-ish code across the barrier between compile- and run-time.

For example, instead of lowering everything to something like webassembly (and even webassembly with SIMD), you could lower to pseudo-ops (that's effectively what the blackbox ops above are). In the simplest case (above) those pseudo-ops are just calls to well-known external functions that do the work. But when using structured serialization forms like SPIR-V or wasm it's possible to do efficient single-pass transformation on them without the need for a giant compiler infrastructure; think of each blackbox op as an intrinsic that a lightweight runtime linker knows about. Your linker could be dumb (or running in -O0) and take any call my.blackbox_op and turn that into a call to that function implemented in C, or it could be a bit smarter and replace call my.blackbox_op with any arbitrary sequence of instructions it wanted (or just have a library of functions embedded that it calls and links across).
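
As a toy illustration of the dumb-linker case (every name here is made up), the resolution step can be as simple as a name-to-function-pointer table, with anything unresolved falling back to the wasm compatibility shim:

// Toy sketch of -O0 intrinsic resolution: map each pseudo/blackbox op name to
// a native implementation; anything unknown falls back to a generic shim
// module. All names here are hypothetical.
#include <stddef.h>
#include <string.h>

extern void ruy_whatever_impl(void);      // hand-written kernel in C
extern void loopy_thing_reference(void);  // stock for-loop-around-body shim

typedef struct {
  const char *name;
  void (*fn)(void);
} intrinsic_entry_t;

static const intrinsic_entry_t kIntrinsics[] = {
    {"ruy.whatever", ruy_whatever_impl},
    {"vector_intrinsics.loopy_thing", loopy_thing_reference},
};

void (*resolve_intrinsic(const char *name))(void) {
  for (size_t i = 0; i < sizeof(kIntrinsics) / sizeof(kIntrinsics[0]); ++i) {
    if (strcmp(kIntrinsics[i].name, name) == 0) return kIntrinsics[i].fn;
  }
  return NULL;  // caller links the wasm compatibility shim module instead
}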

So you could emit something like this down to wasm:

func @vector_intrinsics.loopy_thing(@loop_body, ...) // extern at runtime
func @ruy.whatever ... // extern at runtime
func @loop_body(%i : index, ...) {
  std.div ...
  call ruy.whatever ...
}
func @entry_point() {
  call vector_intrinsics.loopy_thing %loop_params ... @loop_body
  ...
}

If the runtime had no specializations for the loop-magic vector_intrinsics, it would just link in both a module providing your custom ruy.whatever for import from your native code and a stock vector_intrinsics.loopy_thing in wasm that is just a for-loop around a call to the body:

// (conceptual whole linkage, imagine wasm):
func @vector_intrinsics.loopy_thing(@loop_body, ...) {
  // inlined wasm from the compiler
  for (...) { call @loop_body(...) }
}
func @ruy.whatever ... // still extern at runtime calling into C
func @loop_body(%i : index, ...) {
  std.div ...
  call ruy.whatever ...
}
func @entry_point() {
  call vector_intrinsics.loopy_thing %loop_params ... @loop_body
  ...
}

There's no need to add custom instructions to the ISA (new wasm instructions, etc). Any runtime that can execute this code can now do so forever by including a compatibility shim with that builtin - possibly even fetched on demand or bundled with the previous unchanged 10 year old module itself.

But if the runtime linker did want to do something clever here - like unrolling based on vector machine width, etc - it has enough information to do so in the loop params, captures, etc. So instead of merging in the compatibility shim it can progressively enhance the whole thing by looking for those calls and expanding them into its own IR that it can then optimize across. So now the JIT (or a preprocessing migration/upgrade tool that happens once and is cached, etc) can fully optimize that loopy_thing (inlining/hoisting/etc) in a very generic way. Old stuff will work on new stuff and new stuff can continue to work on old stuff (within reason).

And to re-iterate, this is not giant whole-program multi-MB binary rewriting - and, almost more importantly, it requires zero work at all on the runtime implementations. Want to bring up a new target platform/architecture/embedding scenario? Include the reference compatibility shims that turn all your fancy builtins or blackbox ops into simple loops and get things working in a day, not months. You now have the power to scale the same exact deployable artifacts from microcontrollers to servers and optimize for each where it matters, and do so timeshifted forever (effectively). Want to dynamically deploy a model to a 1 year old released app on user devices? (or even a 1 month old released app?) You can do that now. Want to optimize an existing deployed model downloaded from something like tfhub for a new device 4 years after it was published? You can do that now.

None of this means that we can't also be making improvements to the whole stack throughout; it's all about decoupling the improvements we want to make from the day to day stuff that is not related to those improvements. It seems crazy that ML has backslid so much here when it's possible to run DirectX8 games on Windows (and Linux!) in 2020. This is all solved stuff. Even this particular approach is solved; it's why dynamic linkage of libraries into binaries can be a good thing (it can also be hell but it's all about tradeoffs and scope). The JIT style is also solved; this approach is how builtins in v8 work - the JS->IR frontend emits builtins, and the IR->MC backend can optionally (in many cases) choose to specialize those. Even though these are all deployed together the engineering benefits of decoupling the platform-agnostic frontend from the platform-specific backend makes porting, testing, and updating tractable.

Incidentally, all of this is also possible with SPIR-V instead of wasm. SPIR-V extensions are baked right into the binary specification to allow structured ways of doing this, as well as runtime capability queries to take advantage of them. The scale of a wasm->wasm migration tool is the same as a spirv->spirv migration tool - a few hundred lines of code with no library dependencies. wasm is interesting because it scales down much lower on the spectrum (to microcontrollers), as in those cases you can have literally zero work happen at runtime and directly interpret the wasm, whereas spirv may be a more natural form for representing vector dialect-style operations when you want to run a beefier set of passes (like GPU drivers do today).

Cooperative matrix in spirv is a good example of the exact kind of builtin/intrinsic we may want to support like this:
https://github.com/KhronosGroup/GLSL/blob/master/extensions/nv/GLSL_NV_cooperative_matrix.txt#L203-L324
It adds a bunch of "functions" (intrinsics) that the runtime compiler can use to lower from a high-fidelity/-information density form to a machine-specific implementation. But nothing about those functions can't also be done in base spirv itself with a loop and some index manipulation; the benefit here though is that there is no need for a complex raising/pattern matching/optimization process at runtime - it's a pure simple lowering - and that any driver could support these even if they didn't have any special hardware to execute them by just lowering them to their more primitive form. The critical bit is that information that enables significant hardware-specific runtime-decided optimization has been preserved from the compiler through to runtime, it's durable at-rest (you can load a SPIR-V binary using this extension from now till the end of time - even if you then just convert it to some other form), and it allows for progressive optimization.
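
For intuition, here's a sketch of that "more primitive form" for the multiply-add part - what a driver with no special hardware could lower it to, assuming row-major f32 matrices purely for illustration:

// Hypothetical reference lowering of a cooperative-matrix multiply-add
// (C += A * B) when no special hardware path exists: just loops and index
// math. Row-major MxK * KxN accumulating into MxN; all shapes assumed.
static void coop_mat_mul_add_ref(const float *a, const float *b, float *c,
                                 int m_dim, int n_dim, int k_dim) {
  for (int m = 0; m < m_dim; ++m) {
    for (int n = 0; n < n_dim; ++n) {
      float acc = c[m * n_dim + n];
      for (int k = 0; k < k_dim; ++k) {
        acc += a[m * k_dim + k] * b[k * n_dim + n];
      }
      c[m * n_dim + n] = acc;
    }
  }
}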
