# Created September 22, 2017 22:20
# Save mrdrozdov/75f04e2ffaf527cb2083d6c711c85320 to your computer and use it in GitHub Desktop.
# Cartesian Product with Broadcasting
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals hidden
# Unicode characters. Learn more about bidirectional Unicode characters.
import numpy

# Sizes used throughout the demo: rows in each input batch, features per
# row, and width of the linear layer's output.
batch_size = 64
data_size = 784
output_size = 10

# Two random input batches; the Cartesian product pairs every row of x1
# with every row of x2.
x1 = numpy.random.rand(batch_size, data_size)
x2 = numpy.random.rand(batch_size, data_size)

# Parameters of the linear layer applied to each concatenated pair
# [x1_i; x2_j]: the first `data_size` rows of `w` multiply x1_i and the
# remaining rows multiply x2_j.
w = numpy.random.rand(2 * data_size, output_size)
b = numpy.random.rand(output_size)
def cartesian_product_tile(x1, x2, weight=None, bias=None):
    """Apply an affine map to every (row of x1, row of x2) pair, explicitly.

    Each row ``x1_i`` is concatenated with each row ``x2_j`` and the result
    is passed through ``f(v) = v @ weight + bias``, giving a grid ``y``::

        y_{0,0} = f(x1_0; x2_0)
        y_{0,1} = f(x1_0; x2_1)
        ...
        y_{0,n} = f(x1_0; x2_n)
        ...
        y_{m,n} = f(x1_m; x2_n)

    This version materialises all ``len(x1) * len(x2)`` concatenated pairs,
    so it serves as the straightforward reference implementation.

    Args:
        x1: 2-D array of shape (n1, d1).
        x2: 2-D array of shape (n2, d2).
        weight: 2-D array of shape (d1 + d2, k). Defaults to the
            module-level ``w`` when omitted.
        bias: array broadcastable to (k,). Defaults to the module-level
            ``b`` when omitted.

    Returns:
        Array of shape (n1 * n2, k): the row-major flattening of ``y``, so
        row ``i * n2 + j`` holds ``y_{i,j}``.
    """
    # Resolve the defaults at call time so callers that rely on the
    # module-level parameters keep working unchanged.
    weight = w if weight is None else weight
    bias = b if bias is None else bias

    n1, n2 = x1.shape[0], x2.shape[0]
    # Row i * n2 + j must pair x1_i with x2_j. Repeat each x1 row n2 times
    # in a row, and stack n1 copies of the whole of x2. Widths come from
    # the inputs rather than a global size.
    x1_tiled = numpy.repeat(x1, n2, axis=0)
    x2_tiled = numpy.tile(x2, (n1, 1))
    x = numpy.concatenate([x1_tiled, x2_tiled], axis=1)
    return numpy.dot(x, weight) + bias
def cartesian_product_broadcast(x1, x2, weight=None, bias=None):
    """Apply an affine map to every (row of x1, row of x2) pair via broadcasting.

    Produces the same grid as the explicit, tiled construction::

        y_{0,0} = f(x1_0; x2_0)
        y_{0,1} = f(x1_0; x2_1)
        ...
        y_{0,n} = f(x1_0; x2_n)
        ...
        y_{m,n} = f(x1_m; x2_n)

    with ``f(v) = v @ weight + bias``. It relies on the identity
    ``[x1_i; x2_j] @ W == x1_i @ W[:d1] + x2_j @ W[d1:]``: each batch is
    projected once, and the pairwise sum is formed by broadcasting. The
    concatenated pairs are never materialised.

    Args:
        x1: 2-D array of shape (n1, d1).
        x2: 2-D array of shape (n2, d2).
        weight: 2-D array of shape (d1 + d2, k). Defaults to the
            module-level ``w`` when omitted.
        bias: array broadcastable to (k,). Defaults to the module-level
            ``b`` when omitted.

    Returns:
        Array of shape (n1 * n2, k): the row-major flattening of ``y``, so
        row ``i * n2 + j`` holds ``y_{i,j}``.
    """
    # Resolve the defaults at call time so callers that rely on the
    # module-level parameters keep working unchanged.
    weight = w if weight is None else weight
    bias = b if bias is None else bias

    n1, n2 = x1.shape[0], x2.shape[0]
    split = x1.shape[1]          # rows of `weight` that act on x1
    out_dim = weight.shape[1]
    w1, w2 = weight[:split], weight[split:]

    x1_h = numpy.dot(x1, w1).reshape(n1, 1, out_dim)   # (n1, 1, k)
    x2_h = numpy.dot(x2, w2).reshape(1, n2, out_dim)   # (1, n2, k)
    y = x1_h + x2_h + bias                              # broadcasts to (n1, n2, k)
    return y.reshape(n1 * n2, out_dim)
# Self-check: the explicit (tiled) and broadcast implementations must agree
# on every (x1_i, x2_j) pair.
y_tile = cartesian_product_tile(x1, x2)
y_broadcast = cartesian_product_broadcast(x1, x2)
# Unlike a bare `assert`, assert_allclose is not stripped under `python -O`.
# It also rejects mismatched shapes rather than silently broadcasting them,
# and reports which elements differ. rtol=0 keeps the original
# absolute-tolerance check.
numpy.testing.assert_allclose(y_tile, y_broadcast, rtol=0, atol=1e-5)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.