Skip to content

Instantly share code, notes, and snippets.

@straussmaximilian
Created September 21, 2019 13:03
Show Gist options
  • Save straussmaximilian/57a84d2f660dfea660f3eec8cf8175b8 to your computer and use it in GitHub Desktop.
Save straussmaximilian/57a84d2f660dfea660f3eec8cf8175b8 to your computer and use it in GitHub Desktop.
@njit
def boolean_index_numba_multiple(array, xmin, xmax, ymin, ymax, zmin, zmax):
    """
    Return the rows of a numpy array that lie strictly inside a 3D box.

    A row is kept when its first column is in (xmin, xmax), its second
    column is in (ymin, ymax) and its third column is in (zmin, zmax).
    All comparisons are strict, so rows sitting exactly on a bound are
    dropped. The selection is done with a boolean index.
    This function will be compiled with numba.
    """
    xs = array[:, 0]
    ys = array[:, 1]
    zs = array[:, 2]
    # One mask per axis; a row survives only if all three agree.
    inside_x = (xs > xmin) & (xs < xmax)
    inside_y = (ys > ymin) & (ys < ymax)
    inside_z = (zs > zmin) & (zs < zmax)
    return array[inside_x & inside_y & inside_z]
def multiple_queries(data, delta=0.1):
    """
    Count the points of an array that lie in a box around each query point.

    Args:
        data: A pair ``(array, query_points)``. Each query point is indexed
            at positions 0, 1 and 2 for its x, y and z coordinate.
        delta: Half-width of the box built around every query point
            (``coordinate - delta`` to ``coordinate + delta`` per axis).

    Returns:
        The number of matching rows summed over all query points. A row
        that matches several query points is counted once per match.
    """
    array, query_points = data
    count = 0
    # Iterate over the points directly, like multiple_queries_index does.
    for point in query_points:
        xmin, xmax = point[0] - delta, point[0] + delta
        ymin, ymax = point[1] - delta, point[1] + delta
        zmin, zmax = point[2] - delta, point[2] + delta
        # The whole array is scanned for every query point.
        matches = boolean_index_numba_multiple(
            array, xmin, xmax, ymin, ymax, zmin, zmax
        )
        count += len(matches)
    return count
def multiple_queries_index(data, delta=0.1):
    """
    Count the points of an array that lie in a box around each query point.

    The array is sorted on its first column once, so that for every query
    only the slice of rows whose first column can fall in the x-range is
    handed to the boolean filter.

    Args:
        data: A pair ``(array, query_points)``. Each query point is indexed
            at positions 0, 1 and 2 for its x, y and z coordinate.
        delta: Half-width of the box built around every query point
            (``coordinate - delta`` to ``coordinate + delta`` per axis).

    Returns:
        The number of matching rows summed over all query points. A row
        that matches several query points is counted once per match.
    """
    array, query_points = data
    # Sort the array on the first dimension
    sorted_array = array[np.argsort(array[:, 0])]
    # Loop-invariant search key: take the sorted first column once instead
    # of re-slicing it twice per query point.
    sorted_x = sorted_array[:, 0]
    count = 0
    for point in query_points:
        xmin, xmax = point[0] - delta, point[0] + delta
        ymin, ymax = point[1] - delta, point[1] + delta
        zmin, zmax = point[2] - delta, point[2] + delta
        # Binary search for the row range whose x value is in [xmin, xmax];
        # the filter below still applies the x bounds itself.
        min_index = np.searchsorted(sorted_x, xmin, side='left')
        max_index = np.searchsorted(sorted_x, xmax, side='right')
        candidates = sorted_array[min_index:max_index]
        matches = boolean_index_numba_multiple(
            candidates, xmin, xmax, ymin, ymax, zmin, zmax
        )
        count += len(matches)
    return count
array = random_array(1e5,3)
query_points = random_array(1e3, 3)
data = (array, query_points)
print('Multiple queries:\t\t', end='')
%timeit multiple_queries(data)
print('Multiple queries with subset:\t', end='')
%timeit multiple_queries_index(data)
print('Count for multiple_queries: {:,}'.format(multiple_queries(data)))
print('Count for multiple_queries: {:,}'.format(multiple_queries_index(data)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment