
@jessstringham
Last active October 12, 2022 14:17

np.einsum

I had this gross reshape/tensor product/transpose stuff on huge matrices, and I knew it was making intermediate copies of the matrices, which I wanted to avoid. So I tried out np.einsum, and I think it actually turned out simpler than thinking through the other matrix manipulation.

Here are some quick notes.

Real blogs/documentation

This post is great: https://stackoverflow.com/a/33641428 (or http://ajcr.net/Basic-guide-to-einsum/). The docs are useful once you're comfy with it: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.einsum.html

My notes

np.einsum looks pretty scary. Matrix multiplication becomes:

np.einsum('ij,jk->ik', A, B)

I'll keep using that example.
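A quick sanity check that this subscript string really computes the ordinary matrix product. This is a sketch with made-up random inputs of shapes (2, 3) and (3, 4):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((2, 3))  # axes labeled i, j
B = rng.random((3, 4))  # axes labeled j, k

C = np.einsum('ij,jk->ik', A, B)

# Should match the built-in matrix product.
print(C.shape)                # (2, 4)
print(np.allclose(C, A @ B))  # True
```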

ij,jk->ik tells einsum what it should do ("einstein sum subscripts string"). The other arguments are the arrays it should act on ("operands").

ij,jk->ik is like defining a little function array1, array2 -> output.

Each letter labels an axis. ij labels the two axes of A, and jk labels the two axes of B.

I can read ij,jk->ik as "takes a 2D matrix, another 2D matrix, and returns a third 2D matrix."
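To make the "little function" reading concrete, here is a sketch of what ij,jk->ik means written out as plain Python loops. It's slow and only for illustration, and the helper name `matmul_by_loops` is made up:

```python
import numpy as np

def matmul_by_loops(A, B):
    """Hypothetical loop version of np.einsum('ij,jk->ik', A, B)."""
    I, J = A.shape
    J2, K = B.shape
    assert J == J2, "the shared label j must have the same length in both inputs"
    out = np.zeros((I, K))
    for i in range(I):          # i is in the output: keep it
        for k in range(K):      # k is in the output: keep it
            for j in range(J):  # j is missing from the output: sum over it
                out[i, k] += A[i, j] * B[j, k]
    return out

rng = np.random.default_rng(0)
A = rng.random((2, 3))
B = rng.random((3, 4))
print(np.allclose(matmul_by_loops(A, B), np.einsum('ij,jk->ik', A, B)))  # True
```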

Then there are the rules:

  • repeating a letter across the inputs means multiplying elementwise along those axes (http://ajcr.net/Basic-guide-to-einsum/ walks through it)
  • omitting a letter from the right-hand side means summing over that axis.
  • the order of the letters in the output sets the order of the output's axes, so I can transpose too.
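Each rule on its own, as a small sketch on two made-up 2×3 arrays, with the equivalent plain NumPy expression in the comments:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)      # axes i, j
B = np.arange(6, 12).reshape(2, 3)  # axes i, j

# Repeated letter across inputs -> multiply elementwise along it.
print(np.einsum('ij,ij->ij', A, B))  # same as A * B

# Letter missing from the output -> sum over that axis.
print(np.einsum('ij->i', A))         # same as A.sum(axis=1)

# Output letter order is the output axis order.
print(np.einsum('ij->ji', A))        # same as A.T

# Multiply along j, then sum j away.
print(np.einsum('ij,ij->i', A, B))   # same as (A * B).sum(axis=1)
```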

tbh, what ended up working best was not thinking too hard: I labeled the axes of my two inputs and of my output, then followed the rules to fill in the string ("x is the 4th axis of A and the 2nd axis of B, and I know I want to multiply them together"; "the 4th axis of the output should have length x").
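Here's a sketch of that labeling approach on made-up shapes. x is the 4th axis of A and the 2nd axis of B, so the two get multiplied, and x is kept as the 4th axis of the output. c is left off the output, so it's summed over. The broadcasting version at the end is only there to check the answer:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((2, 3, 4, 5))  # axes a, b, c, x
B = rng.random((6, 5))        # axes d, x

# x: 4th axis of A, 2nd axis of B -> multiplied, kept as the 4th output axis.
# c: not in the output -> summed over.
out = np.einsum('abcx,dx->abdx', A, B)
print(out.shape)  # (2, 3, 6, 5)

# Same thing with an explicit sum over c plus broadcasting, for comparison.
expected = A.sum(axis=2)[:, :, None, :] * B[None, None, :, :]
print(np.allclose(out, expected))  # True
```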

If I'm not missing something, once I got over the notation, it turned out simpler to work through than thinking about reshaping and doing tensor products!
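For comparison, here's a sketch of one contraction (sum over a shared axis x) done three ways on made-up shapes: reshape + matrix product + reshape, `np.tensordot`, and `np.einsum`. They should all agree. The einsum string just states directly which axis is summed and what order the output axes come in:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((2, 3, 4, 5))  # axes a, b, c, x
B = rng.random((6, 5))        # axes d, x
# Goal: out[a, b, c, d] = sum over x of A[a, b, c, x] * B[d, x]

# 1. Reshape/transpose: flatten a, b, c into one axis, matmul, then unflatten.
via_reshape = (A.reshape(-1, 5) @ B.T).reshape(2, 3, 4, 6)

# 2. tensordot: contract axis 3 of A with axis 1 of B.
via_tensordot = np.tensordot(A, B, axes=([3], [1]))

# 3. einsum: label the axes and leave x out of the output.
via_einsum = np.einsum('abcx,dx->abcd', A, B)

print(np.allclose(via_reshape, via_einsum))    # True
print(np.allclose(via_tensordot, via_einsum))  # True
```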

Troubleshooting

When I was first messing with it, I kept getting discouraging errors. Unfortunately I didn't write them down. So instead I made some changes to the matrix-multiplication example and noted which errors I got.

Dropping a label

i,jk->ik

ValueError: operand has more dimensions than subscripts given in einstein sum, but no '...' ellipsis provided to broadcast the extra dimensions.

Too many labels

jki,jk->ik

ValueError: einstein sum subscripts string contains too many subscripts for operand 0

New label on right

ij,jk->im

ValueError: einstein sum subscripts string included output subscript 'm' which never appeared in an input

Not enough arrays

jk->ik

ValueError: fewer operands provided to einstein sum function than specified in the subscripts string
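A sketch for re-triggering these errors on purpose with 2D inputs. The exact wording of the messages varies between NumPy versions, so this just prints whatever your version says. For "not enough arrays" I'm assuming the case of passing the full ij,jk->ik string with only one array:

```python
import numpy as np

A = np.ones((2, 3))
B = np.ones((3, 4))

bad_calls = {
    "dropping a label":   lambda: np.einsum('i,jk->ik', A, B),
    "too many labels":    lambda: np.einsum('jki,jk->ik', A, B),
    "new label on right": lambda: np.einsum('ij,jk->im', A, B),
    "not enough arrays":  lambda: np.einsum('ij,jk->ik', A),  # only one operand passed
}

errors = {}
for name, call in bad_calls.items():
    try:
        call()
    except ValueError as e:
        errors[name] = str(e)
        print(f"{name}: {e}")
```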

Mixed-up subscripts

I think this happens when the axis lengths don't match, or when I've mislabeled an axis.

ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (3,2,2,2,3,3)->(3,newaxis,2,2,3,3,2) (2,3,2,2)->(2,newaxis,newaxis,2,2,3) 
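A small sketch that reproduces both causes with 2D arrays. The helper `try_einsum` is made up for this example, and the printed message text depends on the NumPy version:

```python
import numpy as np

A = np.ones((2, 3))  # meant as axes i, j
B = np.ones((3, 4))  # meant as axes j, k

def try_einsum(subscripts, *operands):
    """Hypothetical helper: run np.einsum, return the ValueError message or None."""
    try:
        np.einsum(subscripts, *operands)
    except ValueError as e:
        return str(e)
    return None

# Mislabeled axis: 'ji' makes j length 2 in A, but j is length 3 in B.
print(try_einsum('ji,jk->ik', A, B))

# Wrong shape: j is length 3 in A but length 4 in the second operand.
print(try_einsum('ij,jk->ik', A, np.ones((4, 4))))

# Correct labels and shapes: no error.
print(try_einsum('ij,jk->ik', A, B))  # None
```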
@slvrfn

slvrfn commented Feb 19, 2021

very helpful, thank you!
