|
import warnings |
|
import numpy as np |
|
|
|
from numba import jit as numba_jit |
|
import numba |
|
|
|
#---------------------------------------------------------------------- |
|
# Distance computations |
|
|
|
@numba.jit(nopython=True)
def rdist(X1, i1, X2, i2):
    """Return the squared Euclidean distance between two rows.

    Row ``i1`` of ``X1`` is compared with row ``i2`` of ``X2``; the number
    of coordinates visited is ``X1.shape[1]``.  No square root is taken,
    so the result is the sum of squared coordinate differences.
    """
    total = 0
    n_dims = X1.shape[1]
    for dim in range(n_dims):
        delta = X1[i1, dim] - X2[i2, dim]
        total = total + delta * delta
    return total
|
|
|
|
|
@numba.jit(nopython=True)
def min_rdist(node_centroids, node_radius, i_node, X, j):
    """Return a squared lower bound on the distance from a point to a node.

    The bound is the distance from row ``j`` of ``X`` to the centroid of
    node ``i_node`` minus that node's radius, clamped at zero and then
    squared so that it is comparable with values produced by ``rdist``.
    """
    sq_to_centroid = rdist(node_centroids, i_node, X, j)
    # Negative when the point lies inside the node's ball.
    gap = np.sqrt(sq_to_centroid) - node_radius[i_node]
    return np.square(max(0, gap))
|
|
|
|
|
#---------------------------------------------------------------------- |
|
# Heap for distances and neighbors |
|
|
|
@numba.jit(nopython=True)
def heap_create(N, k):
    """Allocate the neighbour heaps for ``N`` query points.

    Returns a pair of ``(N, k)`` arrays: distances, every entry set to the
    largest finite float64, and int64 indices, every entry set to zero.
    """
    shape = (N, k)
    largest = np.finfo(np.float64).max
    heap_distances = np.full(shape, largest)
    heap_indices = np.zeros(shape, dtype=np.int64)
    return heap_distances, heap_indices
|
|
|
|
|
def heap_sort(distances, indices):
    """Sort every row of ``distances`` ascending and reorder ``indices`` alike.

    Returns new ``(distances, indices)`` arrays; the inputs are not modified.
    """
    # Column vector of row numbers so it broadcasts against the per-row order.
    rows = np.arange(len(distances), dtype=int)[:, np.newaxis]
    order = np.argsort(distances, axis=1)
    sorted_distances = distances[rows, order]
    sorted_indices = indices[rows, order]
    return sorted_distances, sorted_indices
|
|
|
|
|
@numba.jit(nopython=True)
def heap_push(row, val, i_val, distances, indices):
    """Offer the pair ``(val, i_val)`` to the max-heap stored in ``row``.

    Each row of ``distances`` is treated as a binary max-heap whose largest
    value sits in column 0; ``indices`` is kept in step with it.  A value
    larger than the current root is ignored.  Otherwise the root is
    replaced and the new value is sifted down to restore the heap order.
    Both arrays are modified in place; nothing is returned.
    """
    # Reject candidates that are worse than the worst value currently kept.
    if val > distances[row, 0]:
        return

    n_slots = distances.shape[1]

    # The new value takes over the root slot ...
    distances[row, 0] = val
    indices[row, 0] = i_val

    # ... and then sinks: while a child is larger, pull that child up.
    pos = 0
    left = 1
    while left < n_slots:
        right = left + 1

        # Pick the larger child; the left one wins ties and is the only
        # choice when the right child does not exist.
        if right >= n_slots or distances[row, left] >= distances[row, right]:
            child = left
        else:
            child = right

        if val < distances[row, child]:
            distances[row, pos] = distances[row, child]
            indices[row, pos] = indices[row, child]
            pos = child
            left = 2 * pos + 1
        else:
            break

    distances[row, pos] = val
    indices[row, pos] = i_val
|
|
|
#---------------------------------------------------------------------- |
|
# Tools for building the tree |
|
|
|
@numba.jit(nopython=True)
def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
    """Partially order ``idx_array[idx_start:idx_end]`` around ``split_index``.

    The feature with the largest spread (max minus min) over the referenced
    points is chosen as the split dimension.  The slice of ``idx_array`` is
    then rearranged in place, by repeated pivot partitioning, until the
    pivot lands exactly at position ``split_index``.  Nothing is returned.
    """
    n_features = data.shape[1]

    # --- choose the dimension along which the points are spread widest ---
    split_dim = 0
    max_spread = 0
    for dim in range(n_features):
        hi = -np.inf
        lo = np.inf
        for pos in range(idx_start, idx_end):
            value = data[idx_array[pos], dim]
            hi = max(hi, value)
            lo = min(lo, value)
        spread = hi - lo
        if spread > max_spread:
            max_spread = spread
            split_dim = dim

    # --- narrow [left, right] until the pivot settles at split_index ---
    left = idx_start
    right = idx_end - 1
    while True:
        # The last element of the window is the pivot; it is not moved
        # until the scan below has finished.
        pivot = data[idx_array[right], split_dim]
        store = left
        for pos in range(left, right):
            if data[idx_array[pos], split_dim] < pivot:
                idx_array[pos], idx_array[store] = (idx_array[store],
                                                    idx_array[pos])
                store += 1

        # Drop the pivot into its final slot.
        idx_array[store], idx_array[right] = idx_array[right], idx_array[store]

        if store == split_index:
            break
        if store < split_index:
            left = store + 1
        else:
            right = store - 1
|
|
|
|
|
@numba.jit(nopython=True)
def _recursive_build(i_node, idx_start, idx_end,
                     data, node_centroids, node_radius, idx_array,
                     node_idx_start, node_idx_end, node_is_leaf,
                     n_nodes, leaf_size):
    """Fill in node ``i_node`` and, recursively, all nodes below it.

    The node covers the points ``idx_array[idx_start:idx_end]``.  Its
    centroid, radius and index range are written into the ``node_*`` arrays
    in place.  Children of node ``i`` live at ``2 * i + 1`` and
    ``2 * i + 2``; a node becomes a leaf when its first child index would
    fall outside ``n_nodes`` or when it holds fewer than two points.
    """
    n_points = idx_end - idx_start

    # Centroid: per-feature mean of the node's points.
    for dim in range(data.shape[1]):
        node_centroids[i_node, dim] = 0
        for pos in range(idx_start, idx_end):
            node_centroids[i_node, dim] += data[idx_array[pos], dim]
        node_centroids[i_node, dim] /= n_points

    # Radius: distance from the centroid to the farthest point of the node.
    farthest_sq = 0.0
    for pos in range(idx_start, idx_end):
        candidate_sq = rdist(node_centroids, i_node, data, idx_array[pos])
        if candidate_sq > farthest_sq:
            farthest_sq = candidate_sq

    node_radius[i_node] = np.sqrt(farthest_sq)
    node_idx_start[i_node] = idx_start
    node_idx_end[i_node] = idx_end

    first_child = 2 * i_node + 1

    # Neither an oversized leaf (more than 2 * leaf_size points because no
    # child slots are left) nor a node with fewer than two points should
    # occur when the node arrays are sized correctly.  Both are accepted
    # silently as leaves; no warning is issued for either situation.
    if first_child >= n_nodes or n_points < 2:
        node_is_leaf[i_node] = True
    else:
        # Split at the midpoint of the index range and build both halves.
        node_is_leaf[i_node] = False
        n_mid = int((idx_end + idx_start) // 2)
        _partition_indices(data, idx_array, idx_start, idx_end, n_mid)

        _recursive_build(first_child, idx_start, n_mid,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
        _recursive_build(first_child + 1, n_mid, idx_end,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
|
|
|
|
|
#---------------------------------------------------------------------- |
|
# Tools for querying the tree |
|
@numba.jit(nopython=True)
def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB,
                     data, idx_array, node_centroids, node_radius,
                     node_is_leaf, node_idx_start, node_idx_end):
    """Search node ``i_node`` for neighbours of query point ``X[i_pt]``.

    ``sq_dist_LB`` is a squared lower bound on the distance from the query
    point to anything inside the node.  Results accumulate in row ``i_pt``
    of ``heap_distances`` / ``heap_indices``, which are updated in place.
    """
    # Prune: the node cannot contain anything closer than the current
    # worst neighbour held at the heap root.
    if sq_dist_LB > heap_distances[i_pt, 0]:
        return

    # Leaf: test every point stored in the node.
    if node_is_leaf[i_node]:
        for pos in range(node_idx_start[i_node], node_idx_end[i_node]):
            candidate = idx_array[pos]
            sq_dist = rdist(data, candidate, X, i_pt)
            if sq_dist < heap_distances[i_pt, 0]:
                heap_push(i_pt, sq_dist, candidate,
                          heap_distances, heap_indices)
        return

    # Internal node: compute both children's bounds up front, then visit
    # the child with the smaller bound first (the first child wins ties).
    child_a = 2 * i_node + 1
    child_b = child_a + 1
    bound_a = min_rdist(node_centroids, node_radius, child_a, X, i_pt)
    bound_b = min_rdist(node_centroids, node_radius, child_b, X, i_pt)

    if bound_a <= bound_b:
        near = child_a
        near_bound = bound_a
        far = child_b
        far_bound = bound_b
    else:
        near = child_b
        near_bound = bound_b
        far = child_a
        far_bound = bound_a

    _query_recursive(near, X, i_pt, heap_distances,
                     heap_indices, near_bound,
                     data, idx_array, node_centroids, node_radius,
                     node_is_leaf, node_idx_start, node_idx_end)
    _query_recursive(far, X, i_pt, heap_distances,
                     heap_indices, far_bound,
                     data, idx_array, node_centroids, node_radius,
                     node_is_leaf, node_idx_start, node_idx_end)
|
|
|
|
|
@numba.jit(nopython=True, parallel=True)
def _query_parallel(i_node, X, heap_distances, heap_indices,
                    data, idx_array, node_centroids, node_radius,
                    node_is_leaf, node_idx_start, node_idx_end):
    """Run the recursive neighbour search for every row of ``X``.

    The loop over query points uses ``numba.prange``.  Each iteration
    starts the search at node ``i_node`` and passes its own ``i_pt`` as the
    heap row to update.
    """
    n_queries = X.shape[0]
    for i_pt in numba.prange(n_queries):
        start_bound = min_rdist(node_centroids, node_radius, i_node, X, i_pt)
        _query_recursive(i_node, X, i_pt, heap_distances, heap_indices,
                         start_bound,
                         data, idx_array, node_centroids, node_radius,
                         node_is_leaf, node_idx_start, node_idx_end)
|
|
|
|
|
#---------------------------------------------------------------------- |
|
# The Ball Tree object |
|
class BallTree(object):
    """Ball tree supporting k-nearest-neighbour queries.

    The tree is stored in flat arrays: node ``i`` has its children at
    ``2 * i + 1`` and ``2 * i + 2``.  Distances are Euclidean; squared
    distances are used internally and the square root is applied once, to
    the final result of ``query``.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_features)
        Training points.  Converted to a float array.
    leaf_size : int, default 40
        Controls how many nodes are allocated, and thereby roughly how many
        points end up in each leaf.

    Raises
    ------
    ValueError
        If ``data`` is empty, is not two-dimensional, or ``leaf_size < 1``.
    """

    def __init__(self, data, leaf_size=40):
        # Accept any array-like, mirroring what query() does for X.
        self.data = np.asarray(data, dtype=float)
        self.leaf_size = leaf_size

        # validate data
        if self.data.size == 0:
            raise ValueError("X is an empty array")

        if self.data.ndim != 2:
            raise ValueError("data must be a 2D array of shape "
                             "(n_samples, n_features)")

        if leaf_size < 1:
            raise ValueError("leaf_size must be greater than or equal to 1")

        self.n_samples = self.data.shape[0]
        self.n_features = self.data.shape[1]

        # Determine the number of levels in the tree, and from this the
        # number of nodes.  ``n_levels`` is kept as the (generally
        # non-integer) value 1 + log2(q) with q = (n_samples - 1) // leaf_size,
        # so the node count 2 ** n_levels - 1 equals 2 * q - 1.  The count
        # is computed with integer arithmetic: truncating the float power
        # could yield an even node count if it rounds below 2 * q, in which
        # case the build would address a child one past the end of the
        # node arrays.
        n_leaf_slots = int(max(1, (self.n_samples - 1) // self.leaf_size))
        self.n_levels = 1 + np.log2(n_leaf_slots)
        self.n_nodes = 2 * n_leaf_slots - 1

        # allocate arrays for storage
        self.idx_array = np.arange(self.n_samples, dtype=int)
        self.node_radius = np.zeros(self.n_nodes, dtype=float)
        self.node_idx_start = np.zeros(self.n_nodes, dtype=int)
        self.node_idx_end = np.zeros(self.n_nodes, dtype=int)
        self.node_is_leaf = np.zeros(self.n_nodes, dtype=int)
        self.node_centroids = np.zeros((self.n_nodes, self.n_features),
                                       dtype=float)

        # Build the tree in place, starting from the root (node 0), which
        # covers every sample.
        _recursive_build(0, 0, self.n_samples,
                         self.data, self.node_centroids,
                         self.node_radius, self.idx_array,
                         self.node_idx_start, self.node_idx_end,
                         self.node_is_leaf, self.n_nodes, self.leaf_size)

    def query(self, X, k=1, sort_results=True):
        """Find the ``k`` nearest training points for each query point.

        Parameters
        ----------
        X : array-like, last dimension of size n_features
            Query points.  Leading dimensions are preserved in the output.
        k : int, default 1
            Number of neighbours to return; must satisfy
            ``1 <= k <= n_samples``.
        sort_results : bool, default True
            If True, each point's neighbours are ordered by increasing
            distance.  If False, they are returned in the internal heap
            order.

        Returns
        -------
        distances, indices : ndarray, shape ``X.shape[:-1] + (k,)``
            Euclidean distances to the neighbours and their row numbers in
            the training data.

        Raises
        ------
        ValueError
            If the feature dimension does not match the training data, or
            ``k`` is outside ``1 <= k <= n_samples``.
        """
        X = np.asarray(X, dtype=float)

        if X.shape[-1] != self.n_features:
            raise ValueError("query data dimension must "
                             "match training data dimension")

        # The compiled search reads column 0 of each heap row, so a heap
        # with no columns must never reach it.
        if k < 1:
            raise ValueError("k must be greater than or equal to 1")

        if self.data.shape[0] < k:
            raise ValueError("k must be less than or equal "
                             "to the number of training points")

        # flatten X, and save original shape information
        Xshape = X.shape
        X = X.reshape((-1, self.data.shape[1]))

        # initialize heap for neighbors
        heap_distances, heap_indices = heap_create(X.shape[0], k)

        # search the tree from the root for every query point
        _query_parallel(0, X, heap_distances, heap_indices,
                        self.data, self.idx_array, self.node_centroids,
                        self.node_radius, self.node_is_leaf,
                        self.node_idx_start, self.node_idx_end)

        if sort_results:
            distances, indices = heap_sort(heap_distances, heap_indices)
        else:
            distances, indices = heap_distances, heap_indices

        # the heaps hold squared distances
        distances = np.sqrt(distances)

        # deflatten results
        out_shape = Xshape[:-1] + (k,)
        return distances.reshape(out_shape), indices.reshape(out_shape)
|
|
|
|
|
#---------------------------------------------------------------------- |
|
# Testing function |
|
|
|
def test_tree(N=1000, D=3, K=5, LS=40):
    """Compare this BallTree with scikit-learn's on random data.

    Draws ``N`` uniform random points in ``D`` dimensions, queries the ``K``
    nearest neighbours of every point with both implementations (leaf size
    ``LS``), and prints whether the results agree along with wall-clock
    build and query times.
    """
    from time import time
    from sklearn.neighbors import BallTree as skBallTree

    rule = "-------------------------------------------------------"

    print(rule)
    print("Numba version: " + numba.__version__)

    # Pick a seed at random, report it, then reseed with it.
    rseed = np.random.randint(10000)
    print(rule)
    print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
    print("random seed = {0}".format(rseed))
    np.random.seed(rseed)
    X = np.random.random((N, D))

    # pre-run to jit compile the code
    BallTree(X, leaf_size=LS).query(X, K)

    # scikit-learn: build, then query
    start = time()
    sk_tree = skBallTree(X, leaf_size=LS)
    sk_built = time()
    sk_dist, sk_ind = sk_tree.query(X, K)
    sk_queried = time()

    # this module: build, then query
    nb_tree = BallTree(X, leaf_size=LS)
    nb_built = time()
    nb_dist, nb_ind = nb_tree.query(X, K)
    nb_queried = time()

    same_distances = np.allclose(sk_dist, nb_dist)
    same_indices = np.allclose(sk_ind, nb_ind)
    print("results match: {0} {1}".format(same_distances, same_indices))
    print("")
    print("sklearn build: {0:.3g} sec".format(sk_built - start))
    print("numba build : {0:.3g} sec".format(nb_built - sk_queried))
    print("")
    print("sklearn query: {0:.3g} sec".format(sk_queried - sk_built))
    print("numba query : {0:.3g} sec".format(nb_queried - nb_built))
|
|
|
|
|
# Run the benchmark with its default settings when executed as a script.
if __name__ == '__main__':
    test_tree()
Hmm... not quite sure. Just so I understand the code above (the formatting makes it hard to read): are you looking to run this only on datasets that have 20 or fewer rows? What is the dimensionality of each of these datasets, and what leaf size are you attempting to use?