

@shawwn
Last active August 6, 2021 07:36

a conversation about the einops lib and why it's hard to optimize ML:

<PapuaHardyNet> the solution is obviously jax, according to shawn

which is why i'm thinking of also building a jax equivalent and comparing all three: pytorch, pytorch + einops, jax. Let's see if I have the bandwidth to do so

<nshepperd2> einops is really cool imo

<nshepperd2> if there's a performance benefit, it's probably just from operation fusion

<shawwwn> that sounds like a great thing to do. you should.

you'll learn a lot. it's easy to hear "the bottlenecks are in unexpected places", but quite another thing to find them

<nshepperd2> like you can do something in one rearrange that would take a reshape, reduce, transpose, then another reshape
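To make nshepperd2's point concrete, here is a NumPy sketch of the kind of pipeline one einops pattern replaces: a 2x2 mean-pool over an NCHW batch that then moves channels last and flattens each example. The pooling pattern and shapes are arbitrary choices for illustration, and the einops call in the docstring is not executed; treat its exact pattern as an assumption to check against the einops docs. jax.numpy provides the same reshape/mean/transpose calls.

```python
import numpy as np

def pool_and_flatten(x):
    """2x2 mean-pool an NCHW batch, move channels last, flatten per example.

    As a single einops call this would read roughly (not executed here):
        einops.reduce(x, "b c (h h2) (w w2) -> b (h w c)", "mean", h2=2, w2=2)
    """
    b, c, H, W = x.shape
    # 1. reshape: split each spatial axis into (blocks, 2)
    x = x.reshape(b, c, H // 2, 2, W // 2, 2)
    # 2. reduce: average over the two size-2 axes
    x = x.mean(axis=(3, 5))          # -> (b, c, H/2, W/2)
    # 3. transpose: channels last
    x = x.transpose(0, 2, 3, 1)      # -> (b, H/2, W/2, c)
    # 4. reshape: flatten everything except the batch axis
    return x.reshape(b, -1)          # -> (b, H/2 * W/2 * c)

x = np.arange(2 * 3 * 4 * 4, dtype=float).reshape(2, 3, 4, 4)
print(pool_and_flatten(x).shape)     # (2, 12)
```

Four calls, each with its own axis bookkeeping, versus one pattern string that states the input and output layouts directly.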

<PapuaHardyNet> all right, will prioritize this

<shawwwn> well, no. don't do it because it's important. do it because it's fun

<nshepperd2> jax.jit might well do this fusion itself

<shawwwn> it's actually XLA, not jax

jax directly spits out an XLA graph. But the graph is optimized by libtpu, which is where the real magic happens

that's why they only ship obfuscated binaries (apparently) for libtpu

<nshepperd2> but einops over jax would still be a superior ui to tensor.permute()

<shawwwn> honestly I'm a fan of encoding the names into the variables themselves

<shawwwn> https://github.com/shawwn/openai-server/blob/3e1351a82acd771aa648e49c6e874d82e891a61f/openai_server/gpt/jax/model.py#L237-L257

<feepbot> openai-server/model.py at jax-wip · shawwn/openai-server · GitHub (OpenAI API webserver. Contribute to shawwn/openai-server development by creating an account on GitHub.)

<shawwwn> I didn't expect to feel this way, but in my opinion this is the simplest, clearest self-attention implementation I've ever seen

not for humans, but for "people who are just trying to figure out what the fuck the shape of any given variable is"

so for humans, I guess.
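A minimal NumPy sketch of the same naming idea, where every array carries its axis letters as a suffix: b = batch, t/T = query/key position, d = model width, h = heads, r = per-head width. This is not the linked openai-server code; the causal (GPT-style) mask, the 1/sqrt(r) scaling and the weight layouts are assumptions chosen to make a self-contained layer.

```python
import numpy as np

def softmax(x, axis=-1):
    # subtract the max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X_btd, Wq_dhr, Wk_dhr, Wv_dhr, Wo_hrd):
    # project into per-head queries, keys and values; each name states its shape
    Q_bthr = np.einsum("btd,dhr->bthr", X_btd, Wq_dhr)
    K_bThr = np.einsum("bTd,dhr->bThr", X_btd, Wk_dhr)
    V_bThr = np.einsum("bTd,dhr->bThr", X_btd, Wv_dhr)
    r = Q_bthr.shape[-1]
    # raw scores: one (t, T) matrix per batch element and head
    W_bhtT = np.einsum("bthr,bThr->bhtT", Q_bthr, K_bThr) / np.sqrt(r)
    # causal mask (assumption): query t may only look at keys T <= t
    t = X_btd.shape[1]
    mask_tT = np.tril(np.ones((t, t), dtype=bool))
    W_bhtT = softmax(np.where(mask_tT, W_bhtT, -1e10), axis=-1)
    # weighted sum of values, then merge heads back to model width
    A_bthr = np.einsum("bhtT,bThr->bthr", W_bhtT, V_bThr)
    return np.einsum("bthr,hrd->btd", A_bthr, Wo_hrd)
```

Every intermediate can be shape-checked by eye: the letters after the underscore are the axes, in order.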

I also didn't expect to like einsum. I really don't -- it's still a complete mystery to me what W_bhtt = jnp.einsum("bthr,bThr->bhtT", Q_bthr, K_bthr) means

<PapuaHardyNet> that's because the variable naming convention shows the shape

<shawwwn> yes

the only thing I understand about that einsum is that it ends with -> bhtT

and the variable is named W_bhtt

so therefore ... uh ... it's doing ... things

but if I had to sit down and write out what the math ops are that it's doing, I wouldn't be able to

and that's dangerous, because this little einsum expression can hide a lot of performance degradation that you wouldn't expect
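One way to demystify the quoted einsum is to write its definition out as loops. The rule: every letter in the output is looped over, and a letter that appears in the inputs but not in the output (here r) is summed. So W_bhtT[b, h, t, T] is the dot product of query position t with key position T, for each batch element and head. This NumPy sketch names the key tensor K_bThr; the quoted line calls it K_bthr.

```python
import numpy as np

def scores_by_loops(Q_bthr, K_bThr):
    """Spell out np.einsum("bthr,bThr->bhtT", Q_bthr, K_bThr) as explicit loops."""
    B, t_len, H, R = Q_bthr.shape
    T_len = K_bThr.shape[1]
    W_bhtT = np.zeros((B, H, t_len, T_len))
    for b in range(B):                  # b: in both inputs and the output -> batch axis
        for h in range(H):              # h: likewise -> one score matrix per head
            for t in range(t_len):      # t: only in Q -> output rows
                for T in range(T_len):  # T: only in K -> output columns
                    # r: in both inputs but not the output -> summed away (a dot product)
                    W_bhtT[b, h, t, T] = np.dot(Q_bthr[b, t, h, :], K_bThr[b, T, h, :])
    return W_bhtT
```

The loops say nothing about memory layout or which core holds which slice, which is exactly the information that matters once the tensors are sharded.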

<PapuaHardyNet> it can?

<shawwwn> mmhm.

as models get bigger, parallelization and sharding become crucial. But we're still living in the stone age in terms of visualization tools

your performance can be killed simply by sharding one way and not the other way, because the einsum multiplies that way

the clearest example of that is megatron, where without special considerations, a multiply like that might end up touching data on all of the cores at once, inside the innermost loop

so the GPT graph is basically, for i from 0 to n_blocks: block(...)

and each block is self-attention followed by mlp

so if your block ends up doing an einsum across all of the cores, you kill your performance, because you don't need to -- the cross-replica traffic (the all-gather) can be deferred until the end of each block()
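A NumPy simulation of deferring cross-core traffic to the end of a block, with a Python list of weight shards standing in for cores. It follows the Megatron-style split of a transformer MLP (first weight split by columns, second by rows), so each simulated core works independently and the only communication is one sum per block; on real hardware that sum is an all-reduce. The names (mlp_sharded, forward, n_cores) are hypothetical, and the attention half of the block, which Megatron splits by heads, is left out.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def mlp_reference(x, W1, W2):
    # unsharded MLP: one device holds everything
    return gelu(x @ W1) @ W2

def mlp_sharded(x, W1, W2, n_cores):
    W1_shards = np.split(W1, n_cores, axis=1)  # column split: each core owns some hidden units
    W2_shards = np.split(W2, n_cores, axis=0)  # row split: matching rows of the output weight
    partials = []
    for W1_c, W2_c in zip(W1_shards, W2_shards):
        # gelu is elementwise, so each core applies it to its own columns: no communication
        h_c = gelu(x @ W1_c)
        partials.append(h_c @ W2_c)            # a full-width but partial output per core
    # the single cross-core step, deferred to the end of the block (all-reduce on hardware)
    return sum(partials)

def forward(x, layers, n_cores):
    # "for i from 0 to n_blocks: block(...)", with a residual connection per block
    for W1, W2 in layers:
        x = x + mlp_sharded(x, W1, W2, n_cores)
    return x
```

Sharding the same MLP the other way (for example, splitting W1 by rows) would need a cross-core sum before the nonlinearity, adding a second communication step inside every block.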

<PapuaHardyNet> which is why instead of einsum, you believe one should use more fine-grained operations?

<kuudes> I agree with shawwwn

<shawwwn> not necessarily

<nshepperd2> isn't that specific einsum just a naive matrix multiply?

hard to see how it could be faster

<shawwwn> (nope. there's a transpose embedded in it)

<nshepperd2> transposes aren't really

<shawwwn> in my experience, the simplest code is the fastest code. hands down.

<nshepperd2> aren't real
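The same contraction can be written as two transposes plus a batched matrix multiply, which shows where the embedded transpose lives: K's r axis has to move in front of T. A NumPy sketch:

```python
import numpy as np

def scores_einsum(Q_bthr, K_bThr):
    return np.einsum("bthr,bThr->bhtT", Q_bthr, K_bThr)

def scores_transpose_matmul(Q_bthr, K_bThr):
    Q_bhtr = Q_bthr.transpose(0, 2, 1, 3)  # (b, t, h, r) -> (b, h, t, r)
    K_bhrT = K_bThr.transpose(0, 2, 3, 1)  # (b, T, h, r) -> (b, h, r, T): the embedded transpose of K
    return Q_bhtr @ K_bhrT                 # batched over (b, h): (t, r) @ (r, T) -> (t, T)
```

Whether such transposes cost anything at run time is up to the backend. XLA's dot operation takes the batch and contracting dimensions as parameters, so a compiler can often absorb permutations like these instead of materializing them.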

<shawwwn> and I agree with nshepperd. it's very hard to figure out what's "real" when xla optimizations come into play

james bradbury keeps reassuring me that the XLA compiler is designed to optimize specifically the case that I just described

if that's true, then the way to make it fast is to make the code as simple as possible, with as few constraints as possible

because then the compiler will be able to bring all of its magic to bear, and you don't have to do anything

jax made headlines about ... 8 months ago? for breaking mlperf records

https://github.com/google-research/google-research/blob/d26ed752ca5f276f0c35a2d55e213b7cb2ea9339/flax_models/mlperf/transformer/models.py#L317-L319

as far as I can tell, this is the only thing they had to do

<feepbot> google-research/models.py at master · google-research/google-research · GitHub (Google Research. Contribute to google-research/google-research development by creating an account on GitHub.)

<shawwwn> it wasn't quite as easy as that, I'm sure. but the simplicity -- the lack of moving parts -- is essential

if you compare this to megatron's codebase in detail, you'll see that megatron was handcrafted for their exact design. whereas in this case, this code is the opposite; it's crafted to take advantage of the XLA optimizer, and to let it do all the heavy lifting

<+Robomot> image/png (256x256; 206 KB)

<nshepperd2> i wonder if using indefinite vs definite vs no article makes any actual difference to clip

<shawwwn> my first experience with this kind of magic was quite a powerful experience:

https://gist.github.com/shawwn/2af039264ad0639ef43456e8a749b06e

<feepbot> subject: Model Parallelism is Awesome. GitHub Gist: instantly share code, notes, and snippets.

<shawwwn> that's what really convinced me that hand-crafting einsums and so on was hopeless in comparison

<nshepperd2> einsum seems quite superior for optimization to manually reshaping, transposing and reducing stuff

<shawwwn> well, it gets even weirder.

I don't understand how we live in today's world. but we do live in this world:

<nshepperd2> since the einsum says exactly what you want to do, whereas with the latter the optimizer needs to figure out what you meant and fuse the ops back together

<shawwwn> jax's design for xmap is that xmap can turn axes into named tensors

however, their design is that named tensors have no order

if it has a name, it does not correspond to any integer index

it took a long time for me to admit to myself that maybe this isn't as crazy as it sounds

in other words, we're quickly traveling towards a future that looks an awful lot like "not worrying about transposing anything at all"

because if everything is named, then none of it has any ordering.
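A toy illustration of "if everything is named, none of it has any ordering". The NamedArray class below is hypothetical and is not jax's xmap API: axes are matched by name, so the caller never writes a transpose, and the storage order is an internal detail the class picks for itself. Positional order only reappears when you explicitly ask for a view with `to(...)`.

```python
import numpy as np

class NamedArray:
    """Toy array whose axes are addressed only by name (hypothetical, not xmap)."""

    def __init__(self, data, names):
        names = tuple(names)
        assert data.ndim == len(names) and len(set(names)) == len(names)
        self.data, self.names = data, names

    def to(self, *names):
        """Materialize the data in a caller-chosen axis order (for inspection only)."""
        assert set(names) == set(self.names)
        return self.data.transpose([self.names.index(n) for n in names])

    def __add__(self, other):
        # elementwise ops match axes by name, whatever order each operand stores them in
        return NamedArray(self.data + other.to(*self.names), self.names)

    def contract(self, other, over):
        """Multiply, matching shared names, and sum out the axis called `over`."""
        all_names = list(dict.fromkeys(self.names + other.names))
        letter = {n: chr(ord("a") + i) for i, n in enumerate(all_names)}
        out = [n for n in all_names if n != over]
        # build an einsum spec from the names; the output order is whatever comes first
        spec = "{},{}->{}".format(
            "".join(letter[n] for n in self.names),
            "".join(letter[n] for n in other.names),
            "".join(letter[n] for n in out),
        )
        return NamedArray(np.einsum(spec, self.data, other.data), out)

rng = np.random.default_rng(0)
q = NamedArray(rng.normal(size=(2, 5, 3, 4)), ("batch", "query", "head", "feat"))
k = NamedArray(rng.normal(size=(3, 4, 2, 6)), ("head", "feat", "batch", "key"))  # different storage order
scores = q.contract(k, over="feat")
print(scores.to("batch", "head", "query", "key").shape)  # (2, 3, 5, 6)
```

Because the caller never commits to an order, the implementation is free to choose one; a compiler that sees only names has the same freedom when deciding layouts.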

<shawwwn> at that point, the XLA compiler has complete freedom to do a lot of intensive optimizations that you otherwise couldn't do -- apparently. I don't understand why, quite yet.

but their track record speaks for itself, so I've learned to blindly trust that their designs are the result of people many times smarter than I am working together as an effective unit.

<shawwwn> so that's a longform explanation of why I don't really think that the einops library will ever be able to make more than a minor difference in performance. the magic is at a different layer entirely

understanding and exploiting this layer to the fullest seems to be the real path. or at least the path I try to get close to.

<nshepperd2> well sure, you won't need it once all tensors are named and everyone uses TPUs and google finally enters the evil phase of ML market domination

presumably this will be after all nvidia engineers have been assassinated

<shawwwn> nshepperd2: yes. I think assassination is slated for TPU v5
