Created
March 1, 2024 23:35
-
-
Save polvalente/b3701372ab9c22ebc7dab4fdc1433ded to your computer and use it in GitHub Desktop.
Equivalência de tensordot para einsum
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
In [1]: import torch | |
In [2]: a = torch.arange(3, 4, 5) | |
In [3]: a = torch.arange(60).reshape(3, 4, 5) | |
In [4]: b = torch.arange(24).reshape(1, 4, 3, 2) | |
# Aqui, (junto com o output da linha Out[7]) a gente vê que | |
# que o tensordot usando contraction axes específicas tem | |
# esse resultado. | |
# Na prática, a gente pode pensar que o cálculo é equivalente | |
# a transpor esses eixos pro final dos respectivos tensores, | |
# e "flatten" eles em um grande vetor, aí rola soma dos produtos ponto a ponto normalmente. | |
In [5]: torch.tensordot(a, b, dims=((0, 1), (2, 1))) | |
Out[5]: | |
tensor([[[4400, 4730]], | |
[[4532, 4874]], | |
[[4664, 5018]], | |
[[4796, 5162]], | |
[[4928, 5306]]]) | |
In [7]: torch.tensordot(a, b, dims=((0, 1), (2, 1))).shape | |
Out[7]: torch.Size([5, 1, 2]) | |
# Essa expressão de einsum encoda a mesma conta: | |
# a dimensão 0 da esquerda corresponde à dim -2 da direita, | |
# a dim 1 da esquerda, à -3 da direita, e as outras são encaixadas, | |
# na ordem em que aparecem, como leading axes do resultado. | |
In [9]: torch.einsum('ijk,...jiw->k...w', a, b).shape | |
Out[9]: torch.Size([5, 1, 2]) | |
In [10]: torch.einsum('ijk,...jiw->k...w', a, b) | |
Out[10]: | |
tensor([[[4400, 4730]], | |
[[4532, 4874]], | |
[[4664, 5018]], | |
[[4796, 5162]], | |
[[4928, 5306]]]) | |
# abaixo, o desenvolvimento mais intuitivo de o que as duas operações estão fazendo por debaixo dos panos: | |
In [17]: a_flat = torch.permute(a, (2, 0, 1)).reshape((a.shape[2], -1)) | |
In [18]: a_flat.shape | |
Out[18]: torch.Size([5, 12]) | |
In [21]: b_flat = torch.permute(b, (0, 3, 2, 1)).reshape((b.shape[0], b.shape[3], -1)) | |
In [22]: b_flat.shape | |
Out[22]: torch.Size([1, 2, 12]) | |
In [23]: torch.tensordot(a_flat, b_flat, dims=((-1,), (-1,))) | |
Out[23]: | |
tensor([[[4400, 4730]], | |
[[4532, 4874]], | |
[[4664, 5018]], | |
[[4796, 5162]], | |
[[4928, 5306]]]) | |
In [24]: torch.tensordot(a_flat, b_flat, dims=((-1,), (-1,))).shape | |
Out[24]: torch.Size([5, 1, 2]) | |
# Repare que os resultados de Out[23] e de Out[24] são iguais aos respectivos das outras 2 implementações |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment