Last active
June 24, 2021 21:22
-
-
Save brandonwillard/29666a1864d1b9572c41da59d830f4e1 to your computer and use it in GitHub Desktop.
A script that runs comparisons of Aesara implementations of log-sum-exp compiled to Numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
import numba | |
import numpy as np | |
import pandas as pd | |
import aesara | |
import aesara.tensor as at | |
# Square standard-normal inputs of increasing size for the benchmark.
# NOTE: the largest entry holds 10**8 float64 values, i.e. roughly 800 MB
# for the input array alone.
test_data = [
    np.random.normal(size=(side, side)) for side in (3, 25, 100, 1000, 10000)
]
@numba.vectorize
def custom_op_fn(x):
    """Return ``x`` unchanged unless it is infinite, in which case return 0.

    This is the hand-written Numba counterpart of the Aesara graph
    ``at.switch(at.isinf(X), 0, X)`` that this script compiles and times
    against it.
    """
    if np.isinf(x):
        return 0
    return x
# Make sure the underlying Numba function is compiled
custom_op_res = custom_op_fn(test_data[0])

# The same element-wise operation expressed as an Aesara graph
X = at.matrix("X")
Y = at.switch(at.isinf(X), 0, X)

aesara_numba_fn = aesara.function([X], Y, mode="NUMBA")
aesara_c_fn = aesara.function([X], Y, mode="FAST_RUN")

# Pull the jitted function for the whole graph out of the compiled Aesara
# function, presumably so it can be timed without the Aesara call wrapper.
# NOTE(review): this relies on linker internals (the `thunks` closure
# variable and the `fgraph_jit` default argument of the first thunk) and may
# break with other Aesara versions -- confirm before upgrading.
fn, *_ = aesara_numba_fn.maker.linker.make_all()
cl_vars = inspect.getclosurevars(fn)
thunk = cl_vars.nonlocals["thunks"][0]
thunk_signature = inspect.signature(thunk)
aesara_numba_direct_fn = thunk_signature.parameters["fgraph_jit"].default

# Make sure the underlying Numba function is compiled
aesara_numba_res = aesara_numba_fn(test_data[0])
# The direct function returns a sequence of outputs; this graph has one.
(aesara_numba_direct_res,) = aesara_numba_direct_fn(test_data[0])
aesara_c_res = aesara_c_fn(test_data[0])

# Make sure the Numba ufunc and every compiled form of the Aesara graph are
# equivalent before timing them against each other
np.testing.assert_array_almost_equal(custom_op_res, aesara_numba_res)
np.testing.assert_array_almost_equal(custom_op_res, aesara_numba_direct_res)
np.testing.assert_array_almost_equal(custom_op_res, aesara_c_res)
# Result table: one row per input shape (added by the timing loop), one
# column per benchmarked implementation.
timeit_data = pd.DataFrame(
    columns=["Numba", "Aesara-C", "Aesara-Numba", "Aesara-Numba (direct)"],
    index=pd.Index([], name="data shape"),
)
def format_result(x):
    """Condense a ``%timeit -o`` result into ``"<timing> (<loops>)"``.

    Everything in ``str(x)`` from the first ``" per"`` onward is dropped, and
    the loop count ``x.loops`` is appended in parentheses.
    """
    timing, _, _ = str(x).partition(" per")
    return f"{timing} ({x.loops})"
# `get_ipython` is not defined in this file; it is supplied by IPython, so
# this script has to be run from an IPython session.
ipython = get_ipython()

# The timed statements are passed as strings that refer to `data` by name,
# so the loop variable must keep that exact name.
for data in test_data:
    print(f"Running data with shape={data.shape}")

    numba_time = ipython.run_line_magic("timeit", "-o custom_op_fn(data)")
    aesara_c_time = ipython.run_line_magic("timeit", "-o aesara_c_fn(data)")
    aesara_numba_time = ipython.run_line_magic(
        "timeit", "-o aesara_numba_fn(data)"
    )
    aesara_numba_direct_time = ipython.run_line_magic(
        "timeit", "-o aesara_numba_direct_fn(data)"
    )

    # Order must match the columns of `timeit_data`
    timeit_data.loc[str(data.shape)] = [
        format_result(r)
        for r in (
            numba_time,
            aesara_c_time,
            aesara_numba_time,
            aesara_numba_direct_time,
        )
    ]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numba | |
import numpy as np | |
import pandas as pd | |
import aesara | |
import aesara.tensor as at | |
@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp(p, out):
    """Write the log-sum-exp of each row of ``p`` into ``out``.

    Parameters
    ----------
    p
        2-D array of shape ``(n, m)``.
    out
        1-D array of length ``n``; overwritten in place with
        ``log(sum_j exp(p[i, j]))`` for every row ``i``.

    Notes
    -----
    The exponentials are summed directly, without subtracting the row
    maximum first, so a row containing large values overflows to ``inf``.
    """
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert p.ndim == 2
    # Rows are independent, so they are distributed over threads
    for i in numba.prange(n):
        # Float literal keeps the accumulator's type the same throughout
        res = 0.0
        for j in range(m):
            res += np.exp(p[i, j])
        out[i] = np.log(res)
@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp_grad(p, out, dout, dp):
    """Write the gradient of the row-wise log-sum-exp into ``dp``.

    Parameters
    ----------
    p
        2-D array of shape ``(n, m)``.
    out
        1-D array of length ``n``; `LogSumExp.grad` passes the forward
        log-sum-exp result of ``p`` here.
    dout
        1-D array of length ``n``; `LogSumExp.grad` passes the gradient
        with respect to the forward output here.
    dp
        Array with the same shape as ``p``; overwritten in place with
        ``exp(p[i, j] - out[i]) * dout[i]``.
    """
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert len(dout) == n
    assert dout.ndim == 1
    assert dp.shape == p.shape
    # Rows are independent, so they are distributed over threads
    for i in numba.prange(n):
        for j in range(m):
            dp[i, j] = np.exp(p[i, j] - out[i]) * dout[i]
class LogSumExp(aesara.graph.op.Op):
    """Row-wise log-sum-exp of a matrix, evaluated by `numba_logsumexp`.

    Maps a ``dmatrix`` input to a ``dvector`` output with one entry per row.
    The gradient is delegated to `LogSumExpGrad`.
    """

    itypes = [at.dmatrix]
    otypes = [at.dvector]

    def perform(self, node, inputs, outputs):
        """Compute the row-wise log-sum-exp of the single input matrix."""
        (x,) = inputs
        # `numba_logsumexp` fills a caller-provided buffer, one slot per row
        out = np.zeros(x.shape[0], dtype=x.dtype)
        numba_logsumexp(x, out)
        outputs[0][0] = out

    def grad(self, inputs, grads):
        """Return the symbolic gradient with respect to the input matrix."""
        (x,) = inputs
        (dout,) = grads
        # The gradient `Op` takes the forward result as its second input
        forward = self(x)
        return [LogSumExpGrad()(x, forward, dout)]
class LogSumExpGrad(aesara.graph.op.Op):
    """Gradient of `LogSumExp`, evaluated by `numba_logsumexp_grad`.

    Inputs are, in order, the ``dmatrix`` passed to `LogSumExp`, its
    ``dvector`` output, and the ``dvector`` gradient with respect to that
    output. The result is a ``dmatrix`` with the shape of the first input.
    """

    itypes = [at.dmatrix, at.dvector, at.dvector]
    otypes = [at.dmatrix]

    def perform(self, node, inputs, outputs):
        """Compute the gradient with respect to the input matrix."""
        p, out, dout = inputs
        # `numba_logsumexp_grad` fills a caller-provided buffer shaped like `p`
        dp = np.zeros(p.shape, dtype=p.dtype)
        numba_logsumexp_grad(p, out, dout, dp)
        outputs[0][0] = dp
logsumexp = LogSumExp()

# Square standard-normal inputs of increasing size for the benchmark.
# NOTE: the largest entry holds 10**8 float64 values, i.e. roughly 800 MB
# for the input array alone.
test_data = [
    np.random.normal(size=(side, side)) for side in (3, 25, 100, 1000, 10000)
]

# NOTE(review): `at.matrix` presumably follows the configured `floatX`,
# while `LogSumExp` declares a `dmatrix` input -- confirm this is only run
# with floatX=float64.
X = at.matrix("X")

custom_op_fn = aesara.function([X], logsumexp(X))

# Make sure the underlying Numba function is compiled
custom_op_res = custom_op_fn(test_data[0])
def logsumexp2(x, axis=None, keepdims=True):
    """Build an Aesara graph for a max-shifted log-sum-exp of ``x``.

    Parameters
    ----------
    x
        Aesara tensor to reduce.
    axis
        Axis passed to the ``max`` and ``sum`` reductions.
    keepdims
        If true, return the result with the reduced axis kept; otherwise
        return ``res.squeeze()``.
        NOTE(review): ``squeeze()`` is called without an axis, so it
        presumably drops every length-1 dimension, not only ``axis`` --
        confirm for inputs that have other length-1 dimensions.
    """
    shift = at.max(x, axis=axis, keepdims=True)
    # An infinite maximum would make `x - shift` evaluate `inf - inf`;
    # shift by 0 in that case instead.
    shift = at.switch(at.isinf(shift), 0, shift)
    summed = at.sum(at.exp(x - shift), axis=axis, keepdims=True)
    res = at.log(summed) + shift
    return res if keepdims else res.squeeze()
aesara_numba_fn = aesara.function( | |
[X], logsumexp2(X, axis=1, keepdims=False), mode="NUMBA" | |
) | |
aesara_c_fn = aesara.function( | |
[X], logsumexp2(X, axis=1, keepdims=False), mode="FAST_RUN" | |
) | |
# Make sure the underlying Numba function is compiled | |
aesara_numba_res = aesara_numba_fn(test_data[0]) | |
# Make sure the custom `Op` and the Aesara graph are equivalent | |
np.testing.assert_array_almost_equal(custom_op_res, aesara_numba_res) | |
# Result table: one row per input shape (added by the timing loop), one
# column per benchmarked implementation.
timeit_data = pd.DataFrame(
    columns=["custom `Op`", "Aesara graph (Numba)", "Aesara graph (C)"],
    index=pd.Index([], name="data shape"),
)
def format_result(x):
    """Condense a ``%timeit -o`` result into ``"<timing> (<loops>)"``.

    Everything in ``str(x)`` from the first ``" per"`` onward is dropped, and
    the loop count ``x.loops`` is appended in parentheses.
    """
    timing, _, _ = str(x).partition(" per")
    return f"{timing} ({x.loops})"
# `get_ipython` is not defined in this file; it is supplied by IPython, so
# this script has to be run from an IPython session.
ipython = get_ipython()

# The timed statements are passed as strings that refer to `data` by name,
# so the loop variable must keep that exact name.
for data in test_data:
    print(f"Running data with shape={data.shape}")

    custom_op_time = ipython.run_line_magic("timeit", "-o custom_op_fn(data)")
    aesara_numba_time = ipython.run_line_magic(
        "timeit", "-o aesara_numba_fn(data)"
    )
    aesara_c_time = ipython.run_line_magic("timeit", "-o aesara_c_fn(data)")

    # Order must match the columns of `timeit_data`
    timeit_data.loc[str(data.shape)] = [
        format_result(r)
        for r in (custom_op_time, aesara_numba_time, aesara_c_time)
    ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment