Created
February 3, 2019 23:36
-
-
Save eigenfoo/673063880decd9f41009b6054bd77e7f to your computer and use it in GitHub Desktop.
Example of how Einstein notation simplifies tensor manipulations.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/python
"""Compare three equivalent ways to scale and sum image channels in PyTorch.

Each image in a batch has its channels multiplied by a per-image, per-channel
scale factor, and the scaled channels are then summed into a single plane.
The same result is computed with ``torch.einsum``, with a broadcast
multiplication followed by ``torch.sum``, and with a batched ``torch.matmul``.
Each approach is timed once with wall-clock time, so the printed numbers are
only indicative.
"""
from time import time

import torch

batch_size = 128
image_width = 64
image_height = 64
num_channels = 3  # RGB, for instance.

# Suppose we wanted to scale each channel of each image by a certain factor,
# and add them together at the end.
# image_channels: (batch, width, height, channel); scale_factors: (batch, channel).
image_channels = torch.randn(batch_size, image_width, image_height, num_channels)
scale_factors = torch.randn(batch_size, num_channels)

# With einsum, this is straightforward, once you get the hang of thinking in
# terms of subscripts: the channel subscript "l" is absent from the output, so
# it is multiplied elementwise and summed away, leaving (batch, width, height).
start = time()
images = torch.einsum("ijkl,il->ijk", image_channels, scale_factors)
print(time() - start)  # ~ 0.002s

# Alternatively, we could pointwise multiply (Hadamard product) with
# broadcasting. Notice the unsqueezing, and how we have to sum separately.
start = time()
# (batch, channel) -> (batch, 1, 1, channel) so it broadcasts over width/height.
unsqueezed_scale_factors = scale_factors.unsqueeze(1).unsqueeze(2)
images_2 = torch.sum(image_channels * unsqueezed_scale_factors, dim=3)
print(time() - start)  # ~ 0.010s

# The approaches may accumulate in different orders, so the float32 results
# can differ in the last few bits; compare with a tolerance rather than
# demanding bit-for-bit equality.
assert torch.allclose(images, images_2, rtol=1e-5, atol=1e-5)

# Finally, we could do a batched matmul. The matmul sums for us, but we still
# need some squeezing/unsqueezing.
start = time()
# (batch, channel) -> (batch, 1, channel, 1): one column vector per image.
unsqueezed_scale_factors = scale_factors.unsqueeze(1).unsqueeze(3)
# The product has shape (batch, width, height, 1). Drop only that trailing
# axis, so batch/width/height axes of size 1 would survive.
images_3 = torch.matmul(image_channels, unsqueezed_scale_factors).squeeze(-1)
print(time() - start)  # ~ 0.005s

assert torch.allclose(images, images_3, rtol=1e-5, atol=1e-5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment