Created
December 5, 2022 06:46
-
-
Save 99991/08bcb341bd5a47170908d8c762d559c9 to your computer and use it in GitHub Desktop.
KD-Tree query using PyOpenCL to find k nearest neighbors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pyopencl as cl | |
from pymatting import KDTree | |
import pymatting | |
import time | |
# OpenCL C source of the k-nearest-neighbor KD-tree query kernel.
#
# One work-item handles one query point (row `get_global_id(0)` of
# `query_points`).  It walks the tree with an explicit stack and keeps the
# best `k` candidates found so far, sorted by ascending squared distance,
# directly in that query's row of `out_distances` / `out_indices`.
#
# All arrays are flat, row-major views of the host arrays; index arrays are
# 64-bit integers and coordinate/distance arrays are 32-bit floats, as
# declared in the kernel signature.
source = """
typedef int int32_t;
typedef long int64_t;

__kernel void _tree_query_cl(
    __global const int64_t *i0_inds,
    __global const int64_t *i1_inds,
    __global const int64_t *less_inds,
    __global const int64_t *more_inds,
    __global const int64_t *split_dims,
    __global const float *bounds,
    __global const float *split_values,
    __global const float *points,
    __global const float *query_points,
    __global int64_t *out_indices,
    __global float *out_distances,
    int32_t k,
    int32_t n_query,
    int32_t dimension
){
    int i_query = get_global_id(0);

    // Surplus work-items (global size rounded up) have nothing to do
    if (i_query >= n_query) return;

    // Rows of the input/output arrays that belong to this query
    __global const float *query_point = query_points + i_query * dimension;
    __global float *distances = out_distances + i_query * k;
    __global int64_t *indices = out_indices + i_query * k;

    // Explicit traversal stack of node indices, starting at the root (node 0).
    // NOTE: the capacity of 100 entries is not checked when pushing; this
    // assumes the tree is shallow enough - TODO confirm for unbalanced trees.
    int64_t stack[100];
    int n_neighbors = 0;
    stack[0] = 0;
    int stack_size = 1;

    // While there are nodes to visit
    while (stack_size > 0){
        stack_size--;
        int i_node = stack[stack_size];

        // If we found enough neighbors, try to prune this node
        if (n_neighbors >= k){
            // Squared distance from the query point to the node's bounding box
            float dist = 0.0f;
            for (int d = 0; d < dimension; d++){
                float p = query_point[d];

                // bounds shape is (n_nodes, 2, dimension)
                float lower_bound = bounds[i_node * dimension * 2 + 0 * dimension + d];
                float upper_bound = bounds[i_node * dimension * 2 + 1 * dimension + d];

                // Clamp p into [lower_bound, upper_bound]; dp is the overshoot
                float dp = p - fmax(lower_bound, fmin(p, upper_bound));

                dist += dp * dp;
            }

            // Do nothing with this node if all points we have found so far
            // are closer than the bounding box of the node.
            if (dist > distances[n_neighbors - 1]){
                continue;
            }
        }

        // If we are at a leaf (marked by split dimension -1)
        if (split_dims[i_node] == -1){
            // For each point in leaf node
            for (int i = i0_inds[i_node]; i < i1_inds[i_node]; i++){
                // Squared Euclidean distance between query point and point i
                float distance = 0.0f;
                for (int d = 0; d < dimension; d++){
                    float dd = query_point[d] - points[i * dimension + d];
                    distance += dd * dd;
                }

                // Find insert position (candidates are sorted ascending)
                int insert_pos = n_neighbors;
                for (int j = n_neighbors - 1; j >= 0; j--){
                    if (distances[j] > distance) insert_pos = j;
                    else break;
                }

                if (insert_pos < k){
                    // Move [insert_pos:k-1] one to the right to make space.
                    // The last candidate is dropped once the list is full.
                    int j = k - 1;
                    if (j > n_neighbors) j = n_neighbors;
                    for (; j > insert_pos; j--){
                        distances[j] = distances[j - 1];
                        indices[j] = indices[j - 1];
                    }

                    // Insert new neighbor
                    indices[insert_pos] = i;
                    distances[insert_pos] = distance;
                    n_neighbors++;
                    if (n_neighbors > k) n_neighbors = k;
                }
            }
        }else{
            // Descend to child nodes
            int64_t split_dim = split_dims[i_node];
            int64_t less = less_inds[i_node];
            int64_t more = more_inds[i_node];

            // Push the child on the query's side of the split last so that
            // it is popped (and therefore visited) first.
            if (query_point[split_dim] < split_values[i_node]){
                stack[stack_size++] = more;
                stack[stack_size++] = less;
            }else{
                stack[stack_size++] = less;
                stack[stack_size++] = more;
            }
        }
    }
}
"""
# Select an OpenCL device, then create the context, command queue and the
# compiled program that the helper functions below share.
#
# Every installed platform is searched for a GPU first (the GPU is not
# necessarily exposed by the first platform); only if none offers a GPU do we
# fall back to the first device of any type.
platforms = cl.get_platforms()

platform = None
devices = []
for candidate_platform in platforms:
    devices = candidate_platform.get_devices(cl.device_type.GPU)
    if devices:
        platform = candidate_platform
        break

if not devices:
    print("WARNING: OpenCL could not find any GPU device. Trying other devices.")
    for candidate_platform in platforms:
        devices = candidate_platform.get_devices(cl.device_type.ALL)
        if devices:
            platform = candidate_platform
            break

# Raise explicitly instead of asserting: assertions are stripped under -O,
# which would turn this into an opaque IndexError on the next line.
if not devices:
    raise RuntimeError("Could not find any OpenCL-capable device")

device = devices[0]
context = cl.Context([device])
queue = cl.CommandQueue(context)

# Compile the kernel source for the selected device.
program = cl.Program(context, source).build()
def upload(array):
    """Copy a NumPy array into a new read-only OpenCL device buffer.

    Args:
        array: Array-like whose elements are copied to the device in
            row-major (C) order with their dtype unchanged.

    Returns:
        A ``cl.Buffer`` created in the module-level ``context`` and
        initialised with the array's contents.  Kernels may only read from
        it.  The caller is responsible for releasing it.
    """
    # COPY_HOST_PTR needs a single contiguous block of host memory.
    # ascontiguousarray yields exactly the C-order byte layout that flatten()
    # would, but without copying when the array is already contiguous (the
    # previous astype(array.dtype).flatten() always made two full copies).
    hostbuf = np.ascontiguousarray(array)

    return cl.Buffer(
        context,
        cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR,
        hostbuf=hostbuf,
    )
def download(device_buf, shape, dtype):
    """Read an OpenCL device buffer back into a freshly allocated array.

    Args:
        device_buf: Device buffer to copy from.
        shape: Shape of the array to allocate and return.
        dtype: Element type of the array to allocate and return.

    Returns:
        A NumPy array of the requested shape and dtype holding the buffer's
        contents, copied via the module-level ``queue``.
    """
    result = np.empty(shape, dtype)
    cl.enqueue_copy(queue, result, device_buf)
    return result.reshape(shape)
def tree_query_cl(tree, query_points, k):
    """Find the k nearest neighbors of each query point on the OpenCL device.

    Args:
        tree: KD-tree exposing ``i0_inds``, ``i1_inds``, ``less_inds``,
            ``more_inds``, ``split_dims``, ``bounds``, ``split_values``,
            ``shuffled_data_points`` and ``shuffled_indices`` (for example a
            ``pymatting.KDTree``).
        query_points: ``float32`` array of shape ``(n_query, dimension)``
            whose dimension matches the tree's data points.
        k: Number of neighbors to return per query point.

    Returns:
        Tuple ``(distances, indices)``, both of shape ``(n_query, k)``.
        ``distances`` holds Euclidean distances (``float32``), nearest first;
        ``indices`` holds the matching entries of ``tree.shuffled_indices``.

    Raises:
        AssertionError: If ``query_points`` is not ``float32`` or its
            dimension differs from the tree's.
        ValueError: If ``k`` is smaller than 1 or larger than the number of
            data points in the tree.
    """
    assert query_points.dtype == np.float32
    assert query_points.shape[1] == tree.shuffled_data_points.shape[1]

    n_query, dimension = query_points.shape
    n_data = tree.shuffled_data_points.shape[0]

    # The kernel only ever writes as many result slots as it finds points.
    # With k > n_data the remaining slots would stay uninitialised and then
    # be used to index tree.shuffled_indices, so reject that up front.
    if not 1 <= k <= n_data:
        raise ValueError(
            "k must be between 1 and the number of data points "
            "({}), got {}".format(n_data, k)
        )

    # Nothing to launch for an empty query set (and zero-sized buffers or
    # NDRanges are not accepted by every OpenCL implementation).
    if n_query == 0:
        no_indices = np.empty((0, k), np.int64)
        return np.empty((0, k), np.float32), tree.shuffled_indices[no_indices]

    # Inputs in kernel-argument order.  The kernel declares the index arrays
    # as int64_t and everything else as float, so pin the dtypes here; for
    # arrays that already match, np.asarray does not copy.
    index_arrays = (
        tree.i0_inds,
        tree.i1_inds,
        tree.less_inds,
        tree.more_inds,
        tree.split_dims,
    )
    float_arrays = (
        tree.bounds,
        tree.split_values,
        tree.shuffled_data_points,
        query_points,
    )

    n_results = int(n_query) * int(k)
    buffers = []
    try:
        for array in index_arrays:
            buffers.append(upload(np.asarray(array, dtype=np.int64)))
        for array in float_arrays:
            buffers.append(upload(np.asarray(array, dtype=np.float32)))

        # The kernel both writes the result rows and reads them back while
        # inserting and pruning, so the output buffers must be READ_WRITE
        # (accessing them through a READ_ONLY buffer is undefined).  They
        # need no initial contents, hence no host upload.
        gpu_indices = cl.Buffer(
            context,
            cl.mem_flags.READ_WRITE,
            size=n_results * np.dtype(np.int64).itemsize,
        )
        buffers.append(gpu_indices)

        gpu_squared_distances = cl.Buffer(
            context,
            cl.mem_flags.READ_WRITE,
            size=n_results * np.dtype(np.float32).itemsize,
        )
        buffers.append(gpu_squared_distances)

        # One work-item per query point; let OpenCL pick the local size.
        program._tree_query_cl(
            queue,
            (n_query,),
            None,
            *buffers,
            np.int32(k),
            np.int32(n_query),
            np.int32(dimension),
        )

        indices = download(gpu_indices, (n_query, k), np.int64)
        squared_distances = download(
            gpu_squared_distances, (n_query, k), np.float32
        )
    finally:
        # Free device memory even if an upload, the launch or a download fails.
        for buf in buffers:
            buf.release()

    # The kernel reports positions in the tree's shuffled point order; map
    # them through shuffled_indices before returning.
    indices = tree.shuffled_indices[indices]
    distances = np.sqrt(squared_distances)

    return distances, indices
def main():
    """Benchmark the OpenCL query against ``tree.query`` and compare results.

    Builds a KD-tree over random 3D points, then runs both query
    implementations ten times on random query points, printing the elapsed
    wall-clock time in milliseconds and asserting that the returned
    distances agree.
    """
    np.random.seed(0)

    k = 20
    n_points = 256 * 512

    data_points = np.random.rand(n_points, 3).astype(np.float32)
    query_points = np.random.rand(n_points, 3).astype(np.float32)

    # Time the tree construction.
    start = time.perf_counter()
    tree = pymatting.KDTree(data_points)
    elapsed = time.perf_counter() - start

    print("\nbuild KD tree:", elapsed * 1000, "ms\n")

    for _ in range(10):
        # Reference implementation.
        start = time.perf_counter()
        expected_distances, expected_indices = tree.query(query_points, k=k)
        elapsed = time.perf_counter() - start

        print("numba", elapsed * 1000, "ms")

        # OpenCL implementation under test.
        start = time.perf_counter()
        distances, indices = tree_query_cl(tree, query_points, k=k)
        elapsed = time.perf_counter() - start

        print("opencl", elapsed * 1000, "ms\n")

        # Only distances are compared; indices are not checked.
        squared_errors = np.square(distances - expected_distances)
        mean_squared_error = np.mean(squared_errors)

        assert mean_squared_error < 1e-10
        assert np.allclose(distances, expected_distances)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment