Skip to content

Instantly share code, notes, and snippets.

@yangkky
Created November 11, 2018 07:21
Show Gist options
  • Save yangkky/6362205e45e86a18bb744db122a7305f to your computer and use it in GitHub Desktop.
Save yangkky/6362205e45e86a18bb744db122a7305f to your computer and use it in GitHub Desktop.
import itertools

import numpy as np

# Fixed seed so X1, X2 and S are reproducible from run to run.
np.random.seed(9)
# Two random 3x4 matrices of 0/1 integers: uniform samples in [0, 2) are
# truncated toward zero by astype(int).
X1 = np.random.rand(3, 4) * 2
X1 = X1.astype(int)
X2 = np.random.rand(3, 4) * 2
X2 = X2.astype(int)
# One 2x2 table per (row of X1, row of X2) pair: 3 * 3 = 9 tables.
S = np.random.rand(9, 2, 2)
# Reference result: pair i (X1 outer, X2 inner, as itertools.product orders
# them) uses table S[i] and gathers S[i][x1[l], x2[l]] for every column l,
# giving one length-4 array per pair.
subs1 = [S[i][x1, x2] for i, (x1, x2) in enumerate(itertools.product(X1, X2))]
def test5(X1, X2, S):
i, n, n = S.shape
b1, L = X1.shape
b2, L2 = X2.shape
assert i == b1 * b2
assert L == L2
subs = []
c = 0
for x1 in X1:
for x2 in X2:
subs.append(S[c][x1, x2])
c += 1
return np.array(subs)
# Check test5 against the itertools.product reference built above. The
# values are gathered copies, not computed, so they must match exactly.
# np.array_equal also rejects a shape mismatch that np.allclose would
# silently broadcast away.
subs2 = test5(X1, X2, S)
assert np.array_equal(subs1, subs2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment