Skip to content

Instantly share code, notes, and snippets.

@yangkky
Created November 10, 2018 23:51
Show Gist options
  • Save yangkky/551ae27e42c76c7421bc0f919400ba03 to your computer and use it in GitHub Desktop.
Save yangkky/551ae27e42c76c7421bc0f919400ba03 to your computer and use it in GitHub Desktop.
import itertools

import torch

# Demo inputs for fancy_inds. Each row of X1 / X2 is a vector of indices
# (here all 0 or 1) into the n x n Gram matrix built for one (e1, e2) pair,
# so every value must be < n.
X_list = [[1, 0, 1, 0], [1, 1, 0, 0], [0, 1, 1, 1]]
X1 = torch.tensor(X_list)
X_list = [[1, 0, 1, 1], [1, 1, 1, 0], [1, 0, 1, 1], [1, 1, 1, 0]]
X2 = torch.tensor(X_list)
# fancy_inds indexes e1 / e2 by row number of X1 / X2, so the leading
# (batch) dimension of each embedding must equal the number of index rows.
b1 = len(X1)
b2 = len(X2)
n = 2   # rows per embedding, i.e. the size of each n x n Gram matrix
d = 32  # feature width of each embedding
e1 = torch.randn(size=(b1, n, d))
e2 = torch.randn(size=(b2, n, d))
def fancy_inds(e1, e2, X1, X2):
    """Gather Gram-matrix entries for every (row of X1, row of X2) pair.

    For each pair ``(i, j)``, enumerated in ``itertools.product`` order
    (``i`` outer, ``j`` inner), ``e1[i]`` and ``e2[j]`` are concatenated along
    the last dimension into ``s``, the Gram matrix
    ``G = s @ s.transpose(-1, -2)`` is formed, and the entries
    ``G[X1[i], X2[j]]`` are gathered with advanced indexing (``X1[i]`` and
    ``X2[j]`` are broadcast against each other element-wise).

    Args:
        e1: Tensor whose dim 0 is selected by row number of ``X1``. It is
            concatenated with ``e2`` on the last dim, so its other dims must
            match ``e2``'s. The demo uses shape ``(b1, n, d)``.
        e2: Same as ``e1``, selected by row number of ``X2``.
        X1: Rectangular integer index matrix with one row per entry of
            ``e1``; its values index the first axis of each ``G``. Nested
            sequences are converted with ``torch.as_tensor``.
        X2: Same as ``X1`` for ``e2``; its values index the second axis of
            each ``G``.

    Returns:
        Tensor with ``len(X1) * len(X2)`` rows, where row ``i * len(X2) + j``
        holds the gathered entries for pair ``(i, j)``. If either index
        matrix has zero rows, the result has zero rows instead of raising.
    """
    X1 = torch.as_tensor(X1)
    X2 = torch.as_tensor(X2)
    b1 = len(X1)
    b2 = len(X2)
    # Pair p = i * b2 + j: i repeats b2 times in a row, while j cycles through
    # 0..b2-1 once per i -- exactly the itertools.product(X1, X2) order.
    inds1 = torch.arange(b1).repeat_interleave(b2)
    inds2 = torch.arange(b2).repeat(b1)
    S = torch.cat([e1[inds1], e2[inds2]], dim=-1)
    S = S @ S.transpose(-1, -2)
    # A single advanced-indexing gather replaces a Python loop over all pairs:
    # the (P, 1) pair index broadcasts against the (P, L) per-pair index rows,
    # so row p equals S[p][X1[i], X2[j]].
    pairs = torch.arange(b1 * b2).unsqueeze(-1)
    return S[pairs, X1[inds1], X2[inds2]]
# Run the demo. print() makes the result visible when this file is executed
# as a script; a bare expression statement's value is silently discarded.
print(fancy_inds(e1, e2, X1, X2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment