Last active
June 1, 2023 13:27
-
-
Save apatlpo/a2be40ab48b500cfff0eb6e53f1f8bfd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import timeit | |
import numpy as np | |
from numba import njit, int32, float64, prange | |
# benchmarks | |
def generate_data(n0, n1):
    """Build two synthetic point sets filled with standard-normal samples.

    Returns a dict whose keys are a variable name ("x", "y", "z", "v")
    followed by the set suffix ("0" or "1"); set "0" arrays have length
    ``n0`` and set "1" arrays have length ``n1``.
    """
    print(f"Generate data with sizes: {n0} and {n1}")
    sizes = {"0": n0, "1": n1}
    # Suffix-major iteration: all of set "0" is drawn before set "1".
    return {
        name + suffix: np.random.randn(size)
        for suffix, size in sizes.items()
        for name in ("x", "y", "z", "v")
    }
@njit(int32(float64, float64, float64, int32))
def compute_bin(v, edge_min, edge_max, bin_num):
    """Return the index of the regular bin containing ``v``, or -1.

    The interval [edge_min, edge_max) is treated as ``bin_num`` equal-width
    bins. Values outside that interval, and NaN, map to -1.
    Assumes edge_max != edge_min (the bin width is computed by division).
    """
    # np.floor rather than int(): int() truncates towards zero, which would
    # wrongly place slightly-below-range values into bin 0.
    i = np.floor(bin_num * (v - edge_min) / (edge_max - edge_min))
    # Positive range test: every comparison with NaN is False, so NaN is
    # rejected here instead of being cast to an arbitrary int32 on return.
    if i >= 0 and i < bin_num:
        return i
    return -1
# TODO: is compute_bin necessary given np.digitize / np.searchsorted (which could also be vectorized over i1)?
@njit()
def get_bin_info(bin_edges):
    """Summarise a bin-edge array as (first edge, last edge, number of bins)."""
    lo, hi = bin_edges[0], bin_edges[-1]
    # N edges delimit N - 1 bins.
    return lo, hi, len(bin_edges) - 1
### | |
@njit()  # , boundscheck=True
def cx_1d_xprod(
    val0, val1,
    x0, x1, dx_bin_edges, dx_tolerance,
):
    """Serial binned cross product of two point sets.

    For every pair (i0, i1) the separation ``x1[i1] - x0[i0]`` selects a bin
    from ``dx_bin_edges`` (regular bins assumed by ``compute_bin``); the
    product ``val0[i0] * val1[i1]`` is added to that bin. Pairs that fall
    outside the bins, or whose |separation| is not greater than
    ``dx_tolerance``, are ignored.

    Returns (per-bin sum of products, per-bin pair count).
    """
    edge_lo, edge_hi, nbins = get_bin_info(dx_bin_edges)

    sums = np.zeros(nbins, dtype=np.float64)
    hits = np.zeros(nbins, dtype=np.intp)

    for a in range(len(val0)):
        for b in range(len(val1)):
            sep = x1[b] - x0[a]
            k = compute_bin(sep, edge_lo, edge_hi, nbins)
            # Skip out-of-range pairs and pairs closer than the tolerance.
            if k == -1 or not (abs(sep) > dx_tolerance):
                continue
            sums[k] += val0[a] * val1[b]
            hits[k] += 1

    return sums, hits
### | |
@njit()  # , boundscheck=True
def cx_1d_par(
    val0, val1,
    x0, x1, dx_bin_edges, dx_tolerance,
):
    """Two-stage binned cross product (parallel pair pass + serial reduction).

    Stage 1 (``_par_core``) computes, for every (i0, i1) pair, the product
    ``val0[i0] * val1[i1]`` and its bin index. Stage 2 (``_par_add``)
    accumulates those per-pair values into per-bin sums and counts.

    Returns (per-bin sum of products, per-bin pair count).
    """
    dx_min, dx_max, dx_num = get_bin_info(dx_bin_edges)
    # The helpers allocate their own outputs, so nothing is pre-allocated
    # here: doing so would only create large arrays that are thrown away.
    out, index = _par_core(val0, val1, x0, x1, dx_min, dx_max, dx_num, dx_tolerance)
    return _par_add(dx_num, index, out)
@njit(parallel=True)  # , boundscheck=True
def _par_core(
    val0, val1,
    x0, x1, dx_min, dx_max, dx_num, dx_tolerance,
):
    """Compute the product and bin index of every (i0, i1) pair.

    Results are stored in flat arrays of length ``len(val0) * len(val1)``
    at position ``i0 * len(val1) + i1``. A pair whose separation
    ``x1[i1] - x0[i0]`` falls outside the bins, or whose |separation| is not
    greater than ``dx_tolerance``, gets index -1 and a product of 0.

    Returns (products, bin indices).
    """
    n0, n1 = len(val0), len(val1)

    out = np.zeros(n0 * n1, dtype=np.float64)
    index = np.zeros(n0 * n1, dtype=np.intp)

    for i0 in prange(n0):
        for i1 in range(n1):
            # Derive the flat slot from the loop indices. A shared counter
            # incremented inside prange is treated as a reduction variable,
            # so its running value is not a valid per-iteration index.
            i = i0 * n1 + i1
            dx = x1[i1] - x0[i0]
            k = compute_bin(dx, dx_min, dx_max, dx_num)
            if k != -1 and abs(dx) > dx_tolerance:
                index[i] = k
                out[i] = val0[i0] * val1[i1]
            else:
                index[i] = -1
    return out, index
@njit(parallel=False)  # , boundscheck=True
def _par_add(
    n,
    index, out
):
    """Reduce per-pair values into ``n`` bins.

    ``index[p]`` is the destination bin of ``out[p]``; entries with a
    negative index are skipped.

    Returns (per-bin sum, per-bin count).
    """
    sums = np.zeros(n, dtype=np.float64)
    hits = np.zeros(n, dtype=np.intp)
    for pos in range(len(index)):
        k = index[pos]
        if k < 0:
            continue
        sums[k] += out[pos]
        hits[k] += 1
    return sums, hits
### | |
if __name__ == "__main__":
    # Consistency check: the two-stage implementation must reproduce the
    # serial reference exactly on the same random inputs.
    ds = generate_data(10_000, 10_000)
    bins = np.arange(-2, 2, .05)

    # Shared arguments; a tolerance of 0 only drops pairs with dx == 0.
    args = (ds["v0"], ds["v1"], ds["x0"], ds["x1"], bins, 0)

    out_true, count_true = cx_1d_xprod(*args)
    out_par, count_par = cx_1d_par(*args)

    np.testing.assert_equal(out_true, out_par)
    np.testing.assert_equal(count_true, count_par)
    print("done")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment