import warnings

import numpy as np


class FakeJit(object):
    """No-op stand-in for numba.jit, handy for debugging the
    pure-Python code path."""
    def __call__(self, *args, **kwargs):
        if kwargs:
            if args:
                raise ValueError("FakeJit takes either positional "
                                 "or keyword arguments, not both")
            else:
                # called as @jit(...): return self, so the following
                # call receives the decorated function
                return self
        else:
            # called as @jit: return the function unchanged
            return args[0]


from numba import jit as numba_jit
# numba_jit = FakeJit()  # uncomment to bypass compilation
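
# Both decorator spellings work with either jit implementation
# (illustrative only, not part of the original gist):
#
#     @numba_jit                  # returns the function unchanged
#     def f(x): ...
#
#     @numba_jit(nopython=True)   # returns the decorator, which then
#     def g(x): ...               # receives and returns the function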


#----------------------------------------------------------------------
# Distance computations

@numba_jit
def rdist(X1, i1, X2, i2):
    """Squared euclidean distance between rows X1[i1] and X2[i2]"""
    d = 0
    for k in range(X1.shape[1]):
        tmp = (X1[i1, k] - X2[i2, k])
        d += tmp * tmp
    return d


@numba_jit
def min_rdist(node_centroids, node_radius, i_node, X, j):
    """Lower bound on the squared distance between X[j] and any point
    inside node i_node, via the triangle inequality"""
    d = rdist(node_centroids, i_node, X, j)
    return max(0, np.sqrt(d) - node_radius[i_node]) ** 2


#----------------------------------------------------------------------
# Heap for distances and neighbors

def heap_create(N, k):
    """Allocate an (N, k) max-heap pair: distances (initialized to
    infinity) and the matching neighbor indices"""
    distances = np.full((N, k), np.inf, dtype=float)
    indices = np.zeros((N, k), dtype=int)
    return distances, indices


def heap_sort(distances, indices):
    """Sort each row of the heap by ascending distance"""
    i = np.arange(len(distances), dtype=int)[:, None]
    j = np.argsort(distances, 1)
    return distances[i, j], indices[i, j]


@numba_jit
def heap_push(row, val, i_val, distances, indices):
    size = distances.shape[1]

    # check if val should be in heap
    if val > distances[row, 0]:
        return

    # insert val at position zero
    distances[row, 0] = val
    indices[row, 0] = i_val

    # descend the heap, swapping values until the max heap criterion is met
    i = 0
    while True:
        ic1 = 2 * i + 1
        ic2 = ic1 + 1

        if ic1 >= size:
            break
        elif ic2 >= size:
            if distances[row, ic1] > val:
                i_swap = ic1
            else:
                break
        elif distances[row, ic1] >= distances[row, ic2]:
            if val < distances[row, ic1]:
                i_swap = ic1
            else:
                break
        else:
            if val < distances[row, ic2]:
                i_swap = ic2
            else:
                break

        distances[row, i] = distances[row, i_swap]
        indices[row, i] = indices[row, i_swap]

        i = i_swap

    distances[row, i] = val
    indices[row, i] = i_val
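

# A minimal usage sketch of the heap helpers above (the _heap_demo
# function is illustrative and not part of the original gist): keep
# the k smallest of a stream of values for a single query row.
def _heap_demo():
    distances, indices = heap_create(1, 3)
    for i_val, val in enumerate([5.0, 2.0, 9.0, 1.0, 7.0]):
        heap_push(0, val, i_val, distances, indices)
    # rows come back sorted: distances[0] == [1, 2, 5], indices[0] == [3, 1, 0]
    return heap_sort(distances, indices)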


#----------------------------------------------------------------------
# Tools for building the tree

@numba_jit
def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
    """Rearrange idx_array[idx_start:idx_end] so that, along the
    dimension of largest spread, all points before split_index come
    before all points after it (an iterative quickselect)."""
    # Find the split dimension: the feature with maximum spread
    n_features = data.shape[1]

    split_dim = 0
    max_spread = 0

    for j in range(n_features):
        max_val = -np.inf
        min_val = np.inf
        for i in range(idx_start, idx_end):
            val = data[idx_array[i], j]
            max_val = max(max_val, val)
            min_val = min(min_val, val)
        if max_val - min_val > max_spread:
            max_spread = max_val - min_val
            split_dim = j

    # Partition using the split dimension
    left = idx_start
    right = idx_end - 1

    while True:
        midindex = left
        for i in range(left, right):
            d1 = data[idx_array[i], split_dim]
            d2 = data[idx_array[right], split_dim]
            if d1 < d2:
                tmp = idx_array[i]
                idx_array[i] = idx_array[midindex]
                idx_array[midindex] = tmp
                midindex += 1
        tmp = idx_array[midindex]
        idx_array[midindex] = idx_array[right]
        idx_array[right] = tmp
        if midindex == split_index:
            break
        elif midindex < split_index:
            left = midindex + 1
        else:
            right = midindex - 1
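

# Quick illustration of the partition invariant (the _partition_demo
# function is illustrative and not part of the original gist):
def _partition_demo():
    data = np.random.random((10, 3))
    idx_array = np.arange(10)
    _partition_indices(data, idx_array, 0, 10, 5)
    # along the chosen split dimension, every point in idx_array[:5]
    # now lies at or below every point in idx_array[5:]
    return idx_array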


@numba_jit
def _recursive_build(i_node, idx_start, idx_end,
                     data, node_centroids, node_radius, idx_array,
                     node_idx_start, node_idx_end, node_is_leaf,
                     n_nodes, leaf_size):
    # determine Node centroid
    for j in range(data.shape[1]):
        node_centroids[i_node, j] = 0
        for i in range(idx_start, idx_end):
            node_centroids[i_node, j] += data[idx_array[i], j]
        node_centroids[i_node, j] /= (idx_end - idx_start)

    # determine Node radius
    sq_radius = 0.0
    for i in range(idx_start, idx_end):
        sq_dist = rdist(node_centroids, i_node, data, idx_array[i])
        if sq_dist > sq_radius:
            sq_radius = sq_dist

    # set node properties
    node_radius[i_node] = np.sqrt(sq_radius)
    node_idx_start[i_node] = idx_start
    node_idx_end[i_node] = idx_end

    i_child = 2 * i_node + 1

    # recursively create subnodes
    if i_child >= n_nodes:
        node_is_leaf[i_node] = True
        if idx_end - idx_start > 2 * leaf_size:
            # this shouldn't happen if our memory allocation is correct.
            # We'll proactively prevent memory errors, but raise a
            # warning saying we're doing so.
            warnings.warn("Internal: memory layout is flawed: "
                          "not enough nodes allocated")

    elif idx_end - idx_start < 2:
        # again, this shouldn't happen if our memory allocation is correct.
        warnings.warn("Internal: memory layout is flawed: "
                      "too many nodes allocated")
        node_is_leaf[i_node] = True

    else:
        # split node and recursively construct child nodes.
        node_is_leaf[i_node] = False
        n_mid = int((idx_end + idx_start) // 2)
        _partition_indices(data, idx_array, idx_start, idx_end, n_mid)
        _recursive_build(i_child, idx_start, n_mid,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
        _recursive_build(i_child + 1, n_mid, idx_end,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
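

# Note (illustrative, not from the original gist): the tree is stored
# as an implicit binary heap over the flat node arrays. Node i has
# children 2*i + 1 and 2*i + 2, so a node is a leaf exactly when
# 2*i + 1 >= n_nodes, and no explicit child pointers are needed.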


#----------------------------------------------------------------------
# Tools for querying the tree
@numba_jit
def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB,
                     data, idx_array, node_centroids, node_radius,
                     node_is_leaf, node_idx_start, node_idx_end):
    #------------------------------------------------------------
    # Case 1: query point is outside node radius:
    #         trim it from the query
    if sq_dist_LB > heap_distances[i_pt, 0]:
        pass

    #------------------------------------------------------------
    # Case 2: this is a leaf node. Update set of nearby points
    elif node_is_leaf[i_node]:
        for i in range(node_idx_start[i_node],
                       node_idx_end[i_node]):
            dist_pt = rdist(data, idx_array[i], X, i_pt)
            if dist_pt < heap_distances[i_pt, 0]:
                heap_push(i_pt, dist_pt, idx_array[i],
                          heap_distances, heap_indices)

    #------------------------------------------------------------
    # Case 3: Node is not a leaf. Recursively query subnodes
    #         starting with the closest
    else:
        i1 = 2 * i_node + 1
        i2 = i1 + 1
        sq_dist_LB_1 = min_rdist(node_centroids,
                                 node_radius,
                                 i1, X, i_pt)
        sq_dist_LB_2 = min_rdist(node_centroids,
                                 node_radius,
                                 i2, X, i_pt)

        # recursively query subnodes
        if sq_dist_LB_1 <= sq_dist_LB_2:
            _query_recursive(i1, X, i_pt, heap_distances,
                             heap_indices, sq_dist_LB_1,
                             data, idx_array, node_centroids, node_radius,
                             node_is_leaf, node_idx_start, node_idx_end)
            _query_recursive(i2, X, i_pt, heap_distances,
                             heap_indices, sq_dist_LB_2,
                             data, idx_array, node_centroids, node_radius,
                             node_is_leaf, node_idx_start, node_idx_end)
        else:
            _query_recursive(i2, X, i_pt, heap_distances,
                             heap_indices, sq_dist_LB_2,
                             data, idx_array, node_centroids, node_radius,
                             node_is_leaf, node_idx_start, node_idx_end)
            _query_recursive(i1, X, i_pt, heap_distances,
                             heap_indices, sq_dist_LB_1,
                             data, idx_array, node_centroids, node_radius,
                             node_is_leaf, node_idx_start, node_idx_end)


#----------------------------------------------------------------------
# The Ball Tree object
class BallTree(object):
    def __init__(self, data, leaf_size=40):
        self.data = data
        self.leaf_size = leaf_size

        # validate data
        if self.data.size == 0:
            raise ValueError("X is an empty array")

        if leaf_size < 1:
            raise ValueError("leaf_size must be greater than or equal to 1")

        self.n_samples = self.data.shape[0]
        self.n_features = self.data.shape[1]

        # determine number of levels in the tree, and from this
        # the number of nodes in the tree. This results in leaf nodes
        # with numbers of points between leaf_size and 2 * leaf_size
        self.n_levels = 1 + np.log2(max(1, ((self.n_samples - 1)
                                            // self.leaf_size)))
        self.n_nodes = int(2 ** self.n_levels) - 1

        # allocate arrays for storage
        self.idx_array = np.arange(self.n_samples, dtype=int)
        self.node_radius = np.zeros(self.n_nodes, dtype=float)
        self.node_idx_start = np.zeros(self.n_nodes, dtype=int)
        self.node_idx_end = np.zeros(self.n_nodes, dtype=int)
        self.node_is_leaf = np.zeros(self.n_nodes, dtype=int)
        self.node_centroids = np.zeros((self.n_nodes, self.n_features),
                                       dtype=float)

        # recursively build the tree, filling the arrays above
        _recursive_build(0, 0, self.n_samples,
                         self.data, self.node_centroids,
                         self.node_radius, self.idx_array,
                         self.node_idx_start, self.node_idx_end,
                         self.node_is_leaf, self.n_nodes, self.leaf_size)


    def query(self, X, k=1, sort_results=True):
        X = np.asarray(X, dtype=float)

        if X.shape[-1] != self.n_features:
            raise ValueError("query data dimension must "
                             "match training data dimension")

        if self.data.shape[0] < k:
            raise ValueError("k must be less than or equal "
                             "to the number of training points")

        # flatten X, and save original shape information
        Xshape = X.shape
        X = X.reshape((-1, self.data.shape[1]))

        # initialize heap for neighbors
        heap_distances, heap_indices = heap_create(X.shape[0], k)

        for i in range(X.shape[0]):
            sq_dist_LB = min_rdist(self.node_centroids,
                                   self.node_radius,
                                   0, X, i)
            _query_recursive(0, X, i, heap_distances, heap_indices, sq_dist_LB,
                             self.data, self.idx_array, self.node_centroids,
                             self.node_radius, self.node_is_leaf,
                             self.node_idx_start, self.node_idx_end)

        distances, indices = heap_sort(heap_distances, heap_indices)
        distances = np.sqrt(distances)

        # un-flatten results to match the input shape
        return (distances.reshape(Xshape[:-1] + (k,)),
                indices.reshape(Xshape[:-1] + (k,)))
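

# Basic usage (illustrative sketch, not part of the original gist):
#
#     X = np.random.random((1000, 3))
#     bt = BallTree(X, leaf_size=40)
#     dist, ind = bt.query(X[:5], k=3)   # dist.shape == ind.shape == (5, 3)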


#----------------------------------------------------------------------
# Testing function

def test_tree(N=1000, D=3, K=5, LS=40):
    from time import time
    from sklearn.neighbors import BallTree as skBallTree

    rseed = np.random.randint(10000)
    print("-------------------------------------------------------")
    print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
    print("random seed = {0}".format(rseed))
    np.random.seed(rseed)
    X = np.random.random((N, D))

    # pre-run to jit compile the code
    BallTree(X, leaf_size=LS).query(X, K)

    t0 = time()
    bt1 = skBallTree(X, leaf_size=LS)
    t1 = time()
    dist1, ind1 = bt1.query(X, K)
    t2 = time()

    bt2 = BallTree(X, leaf_size=LS)
    t3 = time()
    dist2, ind2 = bt2.query(X, K)
    t4 = time()

    print("results match: {0} {1}".format(np.allclose(dist1, dist2),
                                          np.allclose(ind1, ind2)))
    print("")
    print("sklearn build: {0:.2g} sec".format(t1 - t0))
    print("numba build : {0:.2g} sec".format(t3 - t2))
    print("")
    print("sklearn query: {0:.2g} sec".format(t2 - t1))
    print("numba query : {0:.2g} sec".format(t4 - t3))


if __name__ == '__main__':
    test_tree()
Hi Jake,
I realize this gist is almost 2 years old now, but I was looking into implementing the Ball Tree algorithm with numba and found this wonderful piece of code. At the time, you mentioned that numba was about 10x slower than the cython code in scikit-learn. A lot has happened with numba in the last two years. It looks like inlining of recursive calls has improved, but more importantly, the ParallelAccelerator functionality now gives you free multi-threading if you modify lines 316-323 in ball_tree_numba.py. With that change, here are the results of your code (with my modifications/enhancements) on my 2013 MacBook Pro with Python 3.6:
Here is a new function that takes advantage of the parallelized query:
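Roughly, the idea is to move the per-point loop out of BallTree.query and into a jitted function that uses numba's prange (this is just a sketch; the name _query_parallel is mine, and it assumes the helper functions compile in nopython mode):

    from numba import njit, prange

    @njit(parallel=True)
    def _query_parallel(X, data, idx_array, node_centroids, node_radius,
                        node_is_leaf, node_idx_start, node_idx_end,
                        heap_distances, heap_indices):
        # each query point is independent, so prange spreads the rows
        # across all available cores
        for i in prange(X.shape[0]):
            sq_dist_LB = min_rdist(node_centroids, node_radius, 0, X, i)
            _query_recursive(0, X, i, heap_distances, heap_indices,
                             sq_dist_LB, data, idx_array, node_centroids,
                             node_radius, node_is_leaf,
                             node_idx_start, node_idx_end)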
Please let me know of a good way to get my updated script to you and I will be happy to do so. I tried to attach it to this comment, but it looks like code file attachments are not supported.
Thanks again for providing this awesome gist!!
David