A conversation about the einops library and why it's hard to optimize ML:
<PapuaHardyNet> the solution is obviously jax acc to shawn
which is why i'm thinking of also building a jax equivalent and comparing all three: pytorch, pytorch + einops, jax. Let's see if I have the bandwidth to do so
<nshepperd2> einops is really cool imo
<nshepperd2> if there's a performance benefit, it's probably just from operation fusion
<shawwwn> that sounds like a great thing to do. you should.
you'll learn a lot. it's easy to hear "the bottlenecks are in unexpected places", but quite another thing to find them
<nshepperd2> like you can do something in one rearrange that would take a reshape, reduce, transpose, then another reshape
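The reshape, reduce, transpose, reshape chain nshepperd2 describes can be sketched with NumPy. This is a hypothetical example: the shapes and the pattern `b (h d) t -> (b t) h` are assumptions, not something from the conversation. The manual route takes four separate ops, while einops expresses the same thing as one `reduce` call. The einops comparison only runs if einops happens to be installed.

```python
import numpy as np

def manual_chain(x, h):
    """Split channels into h heads, average within each head, flatten (b, t)."""
    b, hd, t = x.shape
    d = hd // h
    y = x.reshape(b, h, d, t)   # reshape: split (h d) into h, d
    y = y.mean(axis=2)          # reduce: average over d -> (b, h, t)
    y = y.transpose(0, 2, 1)    # transpose: -> (b, t, h)
    return y.reshape(b * t, h)  # reshape: merge (b t) -> (b*t, h)

x = np.random.default_rng(0).standard_normal((2, 6, 5))  # b=2, h*d=6, t=5
out = manual_chain(x, h=3)
print(out.shape)  # (10, 3)

try:
    from einops import reduce  # optional: only checked if einops is installed
    # The same operation written as a single einops pattern.
    out_einops = reduce(x, "b (h d) t -> (b t) h", "mean", h=3)
    print(np.allclose(out, out_einops))
except ImportError:
    pass
```

Whether the one-call form is actually faster depends on the backend fusing the steps; the einops version mainly makes the intent explicit.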
<PapuaHardyNet> all right, will prioritize this
<shawwwn> well, no. don't do it because it's important. do it because it's fun
<nshepperd2> jax.jit might well do this fusion itself
<shawwwn> it's actually XLA, not jax
jax directly spits out an XLA graph. But the graph is optimized by libtpu, which is where the real magic happens
that's why they only ship obfuscated binaries (apparently) for libtpu
<nshepperd2> but einops over jax would still be a superior ui to tensor.permute()
<shawwwn> honestly I'm a fan of encoding the names into the variables themselves
<feepbot> openai-server/model.py at jax-wip · shawwn/openai-server · GitHub (OpenAI API webserver. Contribute to shawwn/openai-server development by creating an account on GitHub.)
<shawwwn> I didn't expect to feel this way, but in my opinion this is the simplest, clearest self attention implementation I've ever seen
not for humans, but for "people who are just trying to figure out what the fuck the shape of any given variable is"
so for humans, I guess.
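The shape-in-the-name convention can be illustrated with a small NumPy self-attention sketch in the same style, where each suffix letter names an axis (b = batch, t/T = query/key position, h = head, r = per-head feature, c = model channels). This is not the code from the linked model.py; the function name, weight layouts, and causal mask are assumptions made for illustration.

```python
import numpy as np

def self_attention(X_btc, Wq_chr, Wk_chr, Wv_chr, Wo_hrc):
    # Project inputs to per-head queries, keys and values.
    Q_bthr = np.einsum("btc,chr->bthr", X_btc, Wq_chr)
    K_bThr = np.einsum("bTc,chr->bThr", X_btc, Wk_chr)
    V_bThr = np.einsum("bTc,chr->bThr", X_btc, Wv_chr)
    r = Q_bthr.shape[-1]
    # Scaled logits: every query position t against every key position T.
    W_bhtT = np.einsum("bthr,bThr->bhtT", Q_bthr, K_bThr) / np.sqrt(r)
    # Causal mask: query t may only attend to keys T <= t.
    t = X_btc.shape[1]
    mask_tT = np.tril(np.ones((t, t), dtype=bool))
    W_bhtT = np.where(mask_tT, W_bhtT, -np.inf)
    # Softmax over the key axis T (max-subtracted for numerical stability).
    W_bhtT = np.exp(W_bhtT - W_bhtT.max(axis=-1, keepdims=True))
    W_bhtT = W_bhtT / W_bhtT.sum(axis=-1, keepdims=True)
    # Weighted sum of values, then project heads back to model channels.
    O_bthr = np.einsum("bhtT,bThr->bthr", W_bhtT, V_bThr)
    return np.einsum("bthr,hrc->btc", O_bthr, Wo_hrc)

rng = np.random.default_rng(0)
b, t, c, h, r = 2, 4, 8, 2, 4
X_btc = rng.standard_normal((b, t, c))
Wq_chr, Wk_chr, Wv_chr = (rng.standard_normal((c, h, r)) for _ in range(3))
Wo_hrc = rng.standard_normal((h, r, c))
Y_btc = self_attention(X_btc, Wq_chr, Wk_chr, Wv_chr, Wo_hrc)
print(Y_btc.shape)  # (2, 4, 8)
```

Every intermediate carries its axes in its name, so a shape bug shows up as a visible mismatch between a name and its einsum subscript.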
I also didn't expect to like einsum. I really don't -- it's still a complete mystery to me what W_bhtt = jnp.einsum("bthr,bThr->bhtT", Q_bthr, K_bthr) means
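For reference, `einsum("bthr,bThr->bhtT", Q, K)` computes W[b,h,t,T] = Σ_r Q[b,t,h,r]·K[b,T,h,r]: for each batch element and head, the dot product of every query position t with every key position T, summing over the feature axis r. The sketch below uses NumPy, whose `np.einsum` takes the same subscript notation as `jnp.einsum`; the shapes are arbitrary. It writes the same expression three ways: as the einsum, as transposes plus a batched matmul, and as plain loops.

```python
import numpy as np

rng = np.random.default_rng(0)
b, t, h, r = 2, 5, 3, 4
Q_bthr = rng.standard_normal((b, t, h, r))
K_bthr = rng.standard_normal((b, t, h, r))

# 1) The einsum itself: contract over r, keep b and h, pair every t with every T.
W_einsum = np.einsum("bthr,bThr->bhtT", Q_bthr, K_bthr)

# 2) Explicit ops: move heads next to batch, then a batched matmul.
Q_bhtr = Q_bthr.transpose(0, 2, 1, 3)   # (b, h, t, r)
K_bhrT = K_bthr.transpose(0, 2, 3, 1)   # (b, h, r, T)
W_matmul = Q_bhtr @ K_bhrT              # (b, h, t, T)

# 3) Plain loops, one output scalar at a time.
W_loops = np.zeros((b, h, t, t))
for bi in range(b):
    for hi in range(h):
        for ti in range(t):
            for Ti in range(t):
                W_loops[bi, hi, ti, Ti] = Q_bthr[bi, ti, hi] @ K_bthr[bi, Ti, hi]

print(np.allclose(W_einsum, W_matmul), np.allclose(W_einsum, W_loops))  # True True
```

The two `transpose` calls in step 2 are the kind of layout change that the one-line einsum leaves implicit.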
<PapuaHardyNet> that's because the variable naming convention shows the shape
<shawwwn> yes
the only thing I understand about that einsum is that it ends with -> bhtT
and the variable is named W_bhtt
so therefore ... uh ... it's doing ... things
but if I had to sit down and write out what the math ops are that it's doing, I wouldn't be able to
and that's dangerous, because this little einsum expression can really hide a lot of performance degradation that you wouldn't expect
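One concrete, single-device way a short einsum string can hide cost is contraction order. In the hypothetical three-operand chain below (NumPy, arbitrary shapes), multiplying left-to-right versus right-to-left gives the same result with a 100x difference in FLOPs, and `np.einsum_path` reports which order NumPy's own planner would pick. This illustrates NumPy's planner, not XLA's, and does not cover the sharding effects discussed next.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 10))
B = rng.standard_normal((10, 1000))
C = rng.standard_normal((1000, 10))

def matmul_flops(m, k, n):
    # Multiply-adds for an (m, k) @ (k, n) product, counted as 2 FLOPs each.
    return 2 * m * k * n

# (A @ B) @ C builds a 1000x1000 intermediate; A @ (B @ C) only a 10x10 one.
left_first = matmul_flops(1000, 10, 1000) + matmul_flops(1000, 1000, 10)
right_first = matmul_flops(10, 1000, 10) + matmul_flops(1000, 10, 10)
print(left_first, right_first)  # 40000000 400000

# Ask NumPy which pairwise contraction order it would use for the same einsum.
path, report = np.einsum_path("ij,jk,kl->il", A, B, C, optimize="optimal")
print(path)    # contraction order chosen by NumPy
print(report)  # includes naive vs optimized FLOP estimates

# Both orders, and the optimized einsum, agree numerically.
assert np.allclose((A @ B) @ C, A @ (B @ C))
assert np.allclose(np.einsum("ij,jk,kl->il", A, B, C, optimize=True), A @ (B @ C))
```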
<PapuaHardyNet> it can?
<shawwwn> mmhm.
as models get bigger, parallelization and sharding become crucial. But we're still living in the stone age in terms of visualization tools
your performance can be killed simply by sharding one way and not the other way, because the einsum multiplies that way
the clearest example of that is megatron, where without special considerations, a multiply like that might end up touching data on all of the cores at once, inside the innermost loop
so the GPT graph is basically, for i from 0 to n_blocks: block(...)
and each block is self-attention followed by mlp
so if your block ends up doing an einsum across all of the cores, you kill your performance, because you don't need to -- the cross replica traffic (the all-gather) can be deferred until the end of each block()
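The Megatron-style idea can be simulated on one machine with NumPy, treating a list of array slices as "cores". This is an illustrative sketch, not Megatron's code: the shapes, the ReLU, and the core count are arbitrary assumptions. The first MLP weight is split by output columns and the second by input rows, so each core runs its whole slice of the block independently and the only cross-core step is one sum at the end, standing in for the collective. Splitting the first weight along its contracted dimension instead forces a reduction before the nonlinearity, i.e. inside the block.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cores, batch, d_model, d_ff = 4, 3, 8, 16
X = rng.standard_normal((batch, d_model))
W1 = rng.standard_normal((d_model, d_ff))
W2 = rng.standard_normal((d_ff, d_model))

def mlp(X, W1, W2):
    return np.maximum(X @ W1, 0.0) @ W2   # ReLU MLP on a single device

# Column-parallel W1 / row-parallel W2: core i owns a d_ff / n_cores slice.
W1_shards = np.split(W1, n_cores, axis=1)   # each (d_model, d_ff / n_cores)
W2_shards = np.split(W2, n_cores, axis=0)   # each (d_ff / n_cores, d_model)

# Each core computes a partial output with no communication at all...
partials = [np.maximum(X @ w1, 0.0) @ w2 for w1, w2 in zip(W1_shards, W2_shards)]
# ...and a single reduction at the end of the block combines them.
Y_sharded = sum(partials)
print(np.allclose(Y_sharded, mlp(X, W1, W2)))  # True

# Counter-example: splitting along the contracted d_model axis yields partial
# pre-activations; ReLU of a sum is not the sum of ReLUs, so the reduction
# would have to happen before the nonlinearity, mid-block.
X_parts = np.split(X, n_cores, axis=1)
W1_rows = np.split(W1, n_cores, axis=0)
bad = sum(np.maximum(x @ w, 0.0) for x, w in zip(X_parts, W1_rows)) @ W2
print(np.allclose(bad, mlp(X, W1, W2)))  # generally differs
```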
<PapuaHardyNet> which is why instead of einsum, you believe one should use more fine-grained operations?
<kuudes> I agree with shawwwn
<shawwwn> not necessarily
<nshepperd2> isn't that specific einsum just a naive matrix multiply
hard to see how it could be faster
<shawwwn> (nope. there's a transpose embedded in it)
<nshepperd2> transposes aren't really
<shawwwn> in my experience, the simplest code is the fastest code. hands down.
<nshepperd2> aren't real
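One concrete sense in which "transposes aren't real": in NumPy a transpose moves no data and just returns a view with permuted strides; a real copy only happens when something later demands a contiguous layout. The sketch below demonstrates this for NumPy only. Whether XLA folds a particular transpose into a neighbouring dot or fusion is the compiler's decision and is not shown here.

```python
import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
y = x.transpose(2, 0, 1)            # shape (4, 2, 3): no data copied

print(np.shares_memory(x, y))       # True: same underlying buffer
print(x.strides, y.strides)         # (48, 16, 4) (4, 48, 16)
print(y.flags["C_CONTIGUOUS"])      # False: only the indexing changed

z = np.ascontiguousarray(y)         # this is where a real copy happens
print(np.shares_memory(x, z))       # False
```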
<shawwwn> and I agree with nshepperd. it's very hard to figure out what's "real" when xla optimizations come into play
james bradbury keeps reassuring me that the XLA compiler is designed to optimize specifically the case that I just described
if that's true, then the way to make it fast is to make the code as simple as possible, with as few constraints as possible
because then the compiler will be able to bring all of its magic to bear, and you don't have to do anything
jax made headlines about ... 8 months ago? for breaking mlperf records
as far as I can tell, this is the only thing they had to do
<feepbot> google-research/models.py at master · google-research/google-research · GitHub (Google Research. Contribute to google-research/google-research development by creating an account on GitHub.)
<shawwwn> it wasn't quite as easy as that, I'm sure. but the simplicity -- the lack of moving parts -- is essential
if you compare this to megatron's codebase in detail, you'll see that megatron was handcrafted for their exact design. whereas in this case, this code is the opposite; it's crafted to take advantage of the XLA optimizer, and to let it do all the heavy lifting
<nshepperd2> i wonder if using indefinite vs definite vs no article makes any actual difference to clip
<shawwwn> my first experience with this kind of magic was quite a powerful experience:
https://gist.github.com/shawwn/2af039264ad0639ef43456e8a749b06e
<feepbot> subject: Model Parallelism is Awesome. GitHub Gist: instantly share code, notes, and snippets.
<shawwwn> that's what really convinced me that hand-crafting einsums and so on was hopeless in comparison
<nshepperd2> einsum seems quite superior for optimization to manually reshaping, transposing and reducing shot
<shawwwn> well, it gets even weirder.
I don't understand how we live in today's world. but we do live in this world:
<nshepperd2> since the einsum says exactly what you want to do, while the latter the optimizer needs to figure out what you meant and fuse the ops back together
<shawwwn> jax's design for xmap is that xmap can turn axes into named tensors
however, their design is that named tensors have no order
if it has a name, it does not correspond to any integer index
it took a long time for me to admit to myself that maybe this isn't as crazy as it sounds
in other words, we're quickly traveling towards a future that looks an awful lot like "not worrying about transposing anything at all"
because if everything is named, then none of it has any ordering.
<shawwwn> at that point, the XLA compiler has complete freedom to do a lot of intensive optimizations that you otherwise couldn't do -- apparently. I don't understand why, quite yet.
but their track record speaks for itself, so I've learned to blindly trust that their designs are the result of people many times smarter than I am working together as an effective unit.
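The "named axes have no order" idea can be illustrated with a toy named-array wrapper. This is a hypothetical sketch: the `Named` class and `contract` helper are made up for illustration and are not xmap's API. Axes are identified only by name, so contracting two arrays never requires the caller to transpose anything, and each input's physical layout has no effect on the result.

```python
import numpy as np
from dataclasses import dataclass
from string import ascii_letters

@dataclass
class Named:
    """Hypothetical array whose axes are referred to by name, not position."""
    data: np.ndarray
    names: tuple

    def align(self, *order):
        # Materialize a plain array in a requested layout (only at the boundary).
        return self.data.transpose([self.names.index(n) for n in order])

def contract(a, b, over):
    """Sum over the shared axis `over`; all other names are kept, in any order."""
    all_names = list(dict.fromkeys(a.names + b.names))   # ordered, de-duplicated
    letters = {n: ascii_letters[i] for i, n in enumerate(all_names)}
    out_names = tuple(n for n in all_names if n != over)
    spec = "{},{}->{}".format(
        "".join(letters[n] for n in a.names),
        "".join(letters[n] for n in b.names),
        "".join(letters[n] for n in out_names),
    )
    return Named(np.einsum(spec, a.data, b.data), out_names)

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 5, 3, 4))   # stored as (b, t, h, r)
K = rng.standard_normal((2, 5, 3, 4))   # stored as (b, T, h, r)
W1 = contract(Named(Q, ("b", "t", "h", "r")), Named(K, ("b", "T", "h", "r")), over="r")

# Same data in scrambled physical layouts: identical result once aligned.
Q_alt = Named(Q.transpose(3, 2, 0, 1), ("r", "h", "b", "t"))
K_alt = Named(K.transpose(1, 0, 3, 2), ("T", "b", "r", "h"))
W2 = contract(Q_alt, K_alt, over="r")
print(np.allclose(W1.align("b", "h", "t", "T"), W2.align("b", "h", "t", "T")))  # True
```

Because the helper is free to choose the einsum output order, layout becomes an implementation detail; the hedged intuition is that a compiler given the same freedom can pick whatever layout is cheapest.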
<shawwwn> so that's a longform explanation of why I don't really think that the einops library will ever be able to make more than a minor difference in performance. the magic is at a different layer entirely
understanding and exploiting this layer to the fullest seems to be the real path. or at least the path I try to get close to.
<nshepperd2> well sure, you won't need it once all tensors are named and everyone uses TPUs and google finally enters the evil phase of ML market domination
presumably this will be after all nvidia engineers have been assassinated
<shawwwn> nshepperd2: yes. I think assassination is slated for TPU v5