Created November 10, 2018 23:51
-
-
Save yangkky/551ae27e42c76c7421bc0f919400ba03 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools

import torch

# Demo inputs for fancy_inds.
# Each row of X1 / X2 holds k = 4 positions. Every entry is used as an index
# into the n axis of the embeddings, so it must lie in [0, n) (here 0 or 1).
torch.manual_seed(0)  # make the random embeddings reproducible
X1 = torch.tensor([[1, 0, 1, 0], [1, 1, 0, 0], [0, 1, 1, 1]])
X2 = torch.tensor([[1, 0, 1, 1], [1, 1, 1, 0], [1, 0, 1, 1], [1, 1, 1, 0]])
# One embedding per row of the index tensors; derive the batch sizes so they
# can never disagree with X1 / X2.
b1 = len(X1)  # 3
b2 = len(X2)  # 4
n = 2
d = 32
e1 = torch.randn(size=(b1, n, d))
e2 = torch.randn(size=(b2, n, d))
def fancy_inds(e1, e2, X1, X2):
    """Gather Gram-matrix entries for every (row of X1, row of X2) pair.

    Pairs ``(i, j)`` are enumerated in row-major order (``i`` over
    ``range(len(X1))`` outer, ``j`` over ``range(len(X2))`` inner), i.e. the
    same order as ``itertools.product``. Pair number ``p = i * len(X2) + j``
    is handled as follows:

    * ``e1[i]`` and ``e2[j]`` are concatenated along the last axis, giving ``s``.
    * The Gram matrix ``G = s @ s.T`` is formed.
    * Row ``p`` of the result is ``G[X1[i], X2[j]]``, an elementwise gather
      that pairs the k-th entry of ``X1[i]`` with the k-th entry of ``X2[j]``.

    Args:
        e1: Tensor whose dim 0 is indexed by row of ``X1``. It needs at least
            3 dims so the Gram matrix is 2-D per pair. Presumably
            ``(b1, n, d1)``; TODO confirm.
        e2: Tensor whose dim 0 is indexed by row of ``X2``. All dims except
            the last must match ``e1`` for the concatenation. Presumably
            ``(b2, n, d2)``.
        X1: Integer index tensor (or nested list) of shape ``(b1, k)``. Its
            entries index the rows of each pair's Gram matrix.
        X2: Integer index tensor (or nested list) of shape ``(b2, k)``. Its
            entries index the columns of each pair's Gram matrix.

    Returns:
        Tensor of shape ``(b1 * b2, k)``, or ``(0, k)`` when either batch is
        empty.
    """
    X1 = torch.as_tensor(X1)
    X2 = torch.as_tensor(X2)
    b1 = len(X1)
    b2 = len(X2)
    # Row-major pair enumeration: inds1 = [0,0,..,1,1,..], inds2 = [0,1,..,0,1,..].
    inds1 = torch.arange(b1).repeat_interleave(b2)
    inds2 = torch.arange(b2).repeat(b1)
    S = torch.cat([e1[inds1], e2[inds2]], dim=-1)
    S = S @ S.transpose(-1, -2)  # one Gram matrix per pair
    # Broadcast (P, 1) x (P, k) x (P, k) -> (P, k): a single vectorized gather
    # replacing the per-pair Python loop.
    pairs = torch.arange(b1 * b2).unsqueeze(-1)
    return S[pairs, X1[inds1], X2[inds2]]
# Run the demo. Printing makes the result visible both when run as a script
# and in a notebook (a bare expression is only displayed in a notebook).
print(fancy_inds(e1, e2, X1, X2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.