I had this gross reshape/tensor product/transpose stuff on huge matrices, and I knew it was making intermediate copies of the matrices that I didn't want. So I tried out np.einsum, and I think it actually turned out simpler than thinking through the other matrix manipulation.
Here are some quick notes.
This post is great: https://stackoverflow.com/a/33641428 (or http://ajcr.net/Basic-guide-to-einsum/). The docs are useful after getting comfy with it. https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.einsum.html
np.einsum looks pretty scary. Matrix multiplication becomes:
np.einsum('ij,jk->ik', A, B)
I'll keep using that example.
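Here is a quick sanity check that the subscripts string really is matrix multiplication. It assumes small random matrices; the shapes (3, 4) and (4, 5) are arbitrary choices for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))  # axes labeled i, j
B = rng.standard_normal((4, 5))  # axes labeled j, k

C = np.einsum('ij,jk->ik', A, B)  # output axes labeled i, k

# Same result as the @ operator (np.matmul)
print(C.shape)                # (3, 5)
print(np.allclose(C, A @ B))  # True
```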
ij,jk->ik tells einsum what it should do (the "einstein sum subscripts string").
The other arguments are the arrays it should act on ("operands").
ij,jk->ik is like defining a little function array1, array2 -> output.
Each letter labels an axis: ij labels the two axes of A.
I can read ij,jk->ik as "take a 2D matrix and another 2D matrix, and return a third 2D matrix."
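To make the "little function" reading concrete, here is a sketch of what ij,jk->ik computes, spelled out as plain Python loops with one loop per letter. It is only for intuition: it is much slower than einsum and is not how NumPy implements it. The helper name matmul_loops is made up for this example.

```python
import numpy as np

def matmul_loops(A, B):
    """Hypothetical helper: 'ij,jk->ik' written as explicit loops."""
    I, J = A.shape          # letters i, j label A's two axes
    J2, K = B.shape         # letters j, k label B's two axes
    assert J == J2, "the shared letter j must have the same length in both"
    out = np.zeros((I, K))  # output axes are i, k (the right-hand side)
    for i in range(I):
        for k in range(K):
            for j in range(J):  # j is missing from the output, so it is summed
                out[i, k] += A[i, j] * B[j, k]
    return out

A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
print(np.allclose(matmul_loops(A, B), np.einsum('ij,jk->ik', A, B)))  # True
```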
Then there are the rules:
- repeating a letter across the inputs means those axes line up, and their entries get multiplied together (http://ajcr.net/Basic-guide-to-einsum/ walks through it)
- omitting a letter from the right-hand side means summing along that axis
- the order of the letters in the output sets the order of the output's axes, so I can transpose too.
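A small sketch of each rule on its own, using two arbitrary 2x3 integer arrays. The comments give the more familiar NumPy expression that each string should match.

```python
import numpy as np

A = np.arange(6).reshape(2, 3)       # labels i j
B = np.arange(6).reshape(2, 3) + 10  # labels i j

# Repeated letters (i and j appear in both inputs): entries multiply pairwise
elementwise = np.einsum('ij,ij->ij', A, B)  # same as A * B

# Letter j left out of the output: sum along that axis
row_sums = np.einsum('ij->i', A)            # same as A.sum(axis=1)

# Output letter order sets the output axis order: a transpose
transposed = np.einsum('ij->ji', A)         # same as A.T
print(transposed.shape)                     # (3, 2)

# Repeated AND omitted: multiply, then sum (row-wise dot products)
row_dots = np.einsum('ij,ij->i', A, B)      # same as (A * B).sum(axis=1)
```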
tbh, what ended up working best was not thinking too hard: labeling the axes of my two inputs and my output, and then following the rules to fill in the string ("the 4th axis is the x dimension in A and the 2nd in B, and I know I want to multiply them together"; "the 4th dimension of the output should be of shape x").
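Here is a sketch of that labeling habit with made-up shapes (all lengths below are hypothetical, picked so every axis is distinct). A is 4D with x as its 4th axis, and B is 3D with x as its 2nd axis. Giving x the same letter in both inputs multiplies along it; putting x in the 4th output slot makes it the output's 4th dimension; c and e are left off the right-hand side, so they get summed.

```python
import numpy as np

a, b, c, x = 2, 3, 4, 5   # hypothetical axis lengths
d, e = 6, 7

A = np.ones((a, b, c, x))  # labels: a b c x  (4th axis is x)
B = np.ones((d, x, e))     # labels: d x e    (2nd axis is x)

# Label A's axes, label B's axes, then write the output labels in the
# order I want. c and e are not on the right-hand side, so they are summed.
out = np.einsum('abcx,dxe->abdx', A, B)
print(out.shape)  # (2, 3, 6, 5): the 4th dimension has length x
```

Since both inputs are all ones, every entry of out is the number of (c, e) pairs summed over, 4 * 7 = 28, which makes it easy to check by eye.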
If I'm not missing something, once I got over the notation, it turned out simpler to work through than thinking about reshaping and doing tensor products!
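For comparison, here is a sketch (hypothetical shapes again) of the same contraction done two ways: summing over a shared x axis with np.tensordot and then reordering the leftover axes with transpose, versus one einsum call. Both give the same numbers, but the tensordot route means tracking where each leftover axis lands (tensordot keeps A's remaining axes first, then B's).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 5))  # labels: a b x
B = rng.standard_normal((6, 5, 7))  # labels: d x e

# Route 1: tensordot sums over x (axis 2 of A, axis 1 of B).
# The result keeps A's leftover axes, then B's: labels a b d e.
T = np.tensordot(A, B, axes=([2], [1]))  # shape (2, 3, 6, 7)
# Reorder a b d e -> e a d b by working out the axis positions by hand.
route1 = T.transpose(3, 0, 2, 1)         # shape (7, 2, 6, 3)

# Route 2: one einsum call; the output string states the order directly.
route2 = np.einsum('abx,dxe->eadb', A, B)

print(route2.shape)                 # (7, 2, 6, 3)
print(np.allclose(route1, route2))  # True
```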
When I was first messing with it, I kept getting discouraging errors. Unfortunately I didn't write them down, so instead I broke the matrix multiplication example on purpose and recorded which errors I got.
i,jk->ik (A has two axes but only gets one letter)
ValueError: operand has more dimensions than subscripts given in einstein sum, but no '...' ellipsis provided to broadcast the extra dimensions.
jki,jk->ik (three letters for the two-axis A)
ValueError: einstein sum subscripts string contains too many subscripts for operand 0
ij,jk->im (m doesn't appear in either input)
ValueError: einstein sum subscripts string included output subscript 'm' which never appeared in an input
jk->ik (the string describes one input, but I passed both A and B)
ValueError: fewer operands provided to einstein sum function than specified in the subscripts string
I think this one happens when the axis lengths don't match, or when I've got an axis mislabeled.
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (3,2,2,2,3,3)->(3,newaxis,2,2,3,3,2) (2,3,2,2)->(2,newaxis,newaxis,2,2,3)
very helpful, thank you!
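A small script that reruns the broken strings from above and prints whatever ValueError comes back. It assumes 2D A and B shaped like the matrix multiplication example (the exact shapes are arbitrary), plus one deliberately mismatched B to trigger the broadcast error. The message wording can differ between NumPy versions, so this only relies on the exception type.

```python
import numpy as np

A = np.ones((3, 4))  # labels i j
B = np.ones((4, 5))  # labels j k

broken = [
    ("i,jk->ik", (A, B)),                 # too few letters for A
    ("jki,jk->ik", (A, B)),               # too many letters for A
    ("ij,jk->im", (A, B)),                # m never appears in an input
    ("jk->ik", (A, B)),                   # one input term, two arrays passed
    ("ij,jk->ik", (A, np.ones((5, 5)))),  # j has length 4 in A but 5 in B
]

errors = []
for subscripts, operands in broken:
    try:
        np.einsum(subscripts, *operands)
    except ValueError as exc:
        errors.append(subscripts)
        print(f"{subscripts!r}: {exc}")
```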