Created September 21, 2019 13:03
Save straussmaximilian/57a84d2f660dfea660f3eec8cf8175b8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@njit
def boolean_index_numba_multiple(array, xmin, xmax, ymin, ymax, zmin, zmax):
    """
    Return the rows of ``array`` that lie strictly inside an axis-aligned box.

    A row ``r`` is kept when ``xmin < r[0] < xmax``, ``ymin < r[1] < ymax``
    and ``zmin < r[2] < zmax``. All bounds are exclusive: points lying
    exactly on a face of the box are dropped.

    This function is compiled with numba (``@njit``).

    Parameters
    ----------
    array : 2-D array whose first three columns are the x, y and z
        coordinates of each point.
    xmin, xmax : exclusive bounds for the first column.
    ymin, ymax : exclusive bounds for the second column.
    zmin, zmax : exclusive bounds for the third column.

    Returns
    -------
    The selected rows, produced by boolean-mask indexing of ``array``.
    """
    # Build one boolean mask combining both bounds of every axis; strict
    # comparisons make the box open on all sides.
    index = ((array[:, 0] > xmin) & (array[:, 0] < xmax)
             & (array[:, 1] > ymin) & (array[:, 1] < ymax)
             & (array[:, 2] > zmin) & (array[:, 2] < zmax))
    return array[index]
def multiple_queries(data, delta=0.1):
    """
    Count array points lying inside a cube of half-width ``delta`` around
    each query point, by box-filtering the full array once per query.

    Parameters
    ----------
    data : tuple ``(array, query_points)``. ``array`` holds the points to
        search, with x, y, z in its first three columns. ``query_points`` is
        an iterable of points whose first three entries are x, y, z.
    delta : half-width of the search cube along every axis. The cube bounds
        are exclusive, as applied by ``boolean_index_numba_multiple``.

    Returns
    -------
    int : total number of matches summed over all query points. A point
        that falls in several query cubes is counted once per query.
    """
    array, query_points = data
    count = 0
    for point in query_points:
        xmin, xmax = point[0] - delta, point[0] + delta
        ymin, ymax = point[1] - delta, point[1] + delta
        zmin, zmax = point[2] - delta, point[2] + delta
        filtered_list = boolean_index_numba_multiple(
            array, xmin, xmax, ymin, ymax, zmin, zmax)
        count += len(filtered_list)
    return count
def multiple_queries_index(data, delta=0.1):
    """
    Count array points lying inside a cube of half-width ``delta`` around
    each query point. The result is the same as ``multiple_queries``, but
    only a pre-selected slice of the array is box-filtered per query.

    The array is sorted on its first (x) column once. For each query a
    binary search locates the slice whose x-values fall in
    ``[xmin, xmax]``, and only that slice is passed to
    ``boolean_index_numba_multiple``.

    Parameters
    ----------
    data : tuple ``(array, query_points)``. ``array`` holds the points to
        search, with x, y, z in its first three columns. ``query_points`` is
        an iterable of points whose first three entries are x, y, z.
    delta : half-width of the search cube along every axis. The cube bounds
        are exclusive.

    Returns
    -------
    int : total number of matches summed over all query points. A point
        that falls in several query cubes is counted once per query.
    """
    array, query_points = data
    # Sort once on the first dimension so each query's x-range is a
    # contiguous run of rows that can be found by binary search.
    sorted_array = array[np.argsort(array[:, 0])]
    # Hoisted out of the loop. np.searchsorted operates on a contiguous
    # haystack, so handing it the strided column view sorted_array[:, 0]
    # can force an O(n) copy on every call, twice per query. Copy it once.
    sorted_x = np.ascontiguousarray(sorted_array[:, 0])
    count = 0
    for point in query_points:
        xmin, xmax = point[0] - delta, point[0] + delta
        ymin, ymax = point[1] - delta, point[1] + delta
        zmin, zmax = point[2] - delta, point[2] + delta
        # side='left'/'right' selects the inclusive range [xmin, xmax]. This
        # is a superset of the strict x-range; the box filter below enforces
        # the exclusive bounds on all three axes.
        min_index = np.searchsorted(sorted_x, xmin, side='left')
        max_index = np.searchsorted(sorted_x, xmax, side='right')
        filtered_list = boolean_index_numba_multiple(
            sorted_array[min_index:max_index],
            xmin, xmax, ymin, ymax, zmin, zmax)
        count += len(filtered_list)
    return count
array = random_array(1e5,3) | |
query_points = random_array(1e3, 3) | |
data = (array, query_points) | |
print('Multiple queries:\t\t', end='') | |
%timeit multiple_queries(data) | |
print('Multiple queries with subset:\t', end='') | |
%timeit multiple_queries_index(data) | |
print('Count for multiple_queries: {:,}'.format(multiple_queries(data))) | |
print('Count for multiple_queries: {:,}'.format(multiple_queries_index(data))) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.