from functools import partial

import numpy as np
from scipy.spatial import cKDTree
from dask.delayed import delayed
import dask.bag as db


class DaskKDTree(object):
    """Distributed collection of scipy cKDTrees, one per chunk of points.

    Usage Example:
    --------------
    from dask.distributed import Client

    client = Client('52.91.203.58:8786')
    tree = DaskKDTree(client, leafsize=1000)
    tree.load_random(num_points=int(1e8), chunk_size=int(1e6))

    # find all points within 10km of Bourbon Street
    bourbon_street = (-10026195.958134, 3498018.476606)
    radius = 10000  # meters
    result = tree.query_ball_point(x=bourbon_street, r=radius)
    """

    def __init__(self, client, leafsize):
        self.client = client
        self.leafsize = leafsize
        self.trees = []

    def load_random(self, num_points=int(1e6), chunk_size=300):
        # build one delayed cKDTree per chunk and keep them in cluster memory
        parts = num_points // chunk_size
        self.trees = [delayed(DaskKDTree._run_load_random)(int(chunk_size), leafsize=self.leafsize)
                      for _ in range(parts)]
        self.trees = self.client.persist(self.trees)

    @staticmethod
    def _run_load_random(count, leafsize):
        # random points over roughly the web mercator extent (+/- 20,000 km)
        xs = np.random.uniform(-20e6, 20e6, count)
        ys = np.random.uniform(-20e6, 20e6, count)
        points = np.dstack((xs, ys))[0, :]  # stack into an (N, 2) array
        return cKDTree(points, leafsize=leafsize)

    def query_ball_point(self, **kwargs):
        # query every per-chunk tree and concatenate the results as a bag
        nearest = [delayed(DaskKDTree._run_query_ball_point)(d, kwargs) for d in self.trees]
        b = db.from_delayed(nearest)
        return b.compute()

    @staticmethod
    def _run_query_ball_point(tree, query_info):
        indices = tree.query_ball_point(**query_info)
        # return a plain list: dask.bag partitions must not be NumPy arrays,
        # whose truth value is ambiguous inside dask.bag's reify
        return tree.data[indices].tolist()
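A hypothetical end-to-end run against a local scheduler (a sketch; LocalCluster, the worker count, and the smaller sizes are assumptions, not part of the gist):

from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=4)
client = Client(cluster)

tree = DaskKDTree(client, leafsize=1000)
tree.load_random(num_points=int(1e6), chunk_size=int(1e5))

# all points within 10 km of Bourbon Street (web mercator meters)
bourbon_street = (-10026195.958134, 3498018.476606)
result = tree.query_ball_point(x=bourbon_street, r=10000)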
What's the approximate time to complete the nearest-neighbor search on a computer with 16 GB of RAM and an i7 processor? For a single point it ran for more than 15 minutes. Is there something I might be doing wrong? I want to run query_ball_point for millions of points.
Facing this error while replicating your example:
File "C:\Anaconda3\lib\site-packages\dask\bag\core.py", line 1460, in reify
if seq and isinstance(seq[0], Iterator):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
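This error typically means dask.bag received NumPy arrays as partitions: reify evaluates "if seq and ...", and the truth value of a multi-element array is ambiguous. A likely fix (an assumption, not confirmed in this thread) is to convert the per-tree result to a plain list, as done in _run_query_ball_point above:

indices = tree.query_ball_point(**query_info)
return tree.data[indices].tolist()  # lists are safe bag partitions; ndarrays are not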
For a small number of KD-trees, this is an interesting approach.
For truly distributed KD-trees, I recently read a paper, "Highly Parallel Fast KD-tree Construction for Interactive Ray Tracing of Dynamic Scenes" [1], that looks like it might be a more optimal way to do this in parallel.
[1] https://onlinelibrary.wiley.com/doi/pdf/10.1111/j.1467-8659.2007.01062.x
My understanding here is that the approach is to have a flat collection of single-machine KD-trees and, when we need to query something, we check them all. Is this right?
I would want to know more about how this is likely to be used, and about other possible options. My guess is that people have analyzed other arrangements, such as partitioning the dataset ahead of time so that different sections of a large tree live on different machines; a sketch of that pre-partitioned arrangement follows below. I suspect there is some tradeoff between the two (or more) approaches.
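For illustration, here is a minimal sketch of the pre-partitioned arrangement, assuming a coarse uniform grid (the cell size and both helper functions are hypothetical, not from this thread). Points are binned into grid cells, each cell gets its own cKDTree, and a ball query only needs to consult the cells its radius can overlap:

import numpy as np
from scipy.spatial import cKDTree

def partition_by_grid(points, cell_size=5e6):
    # bin (N, 2) points into coarse grid cells; one KD-tree per cell
    keys = np.floor(points / cell_size).astype(int)
    buckets = {}
    for key, pt in zip(map(tuple, keys), points):
        buckets.setdefault(key, []).append(pt)
    return {k: cKDTree(np.asarray(v)) for k, v in buckets.items()}

def cells_overlapping(x, r, cell_size=5e6):
    # grid cells intersecting the query ball's bounding box; only the trees
    # for these cells need to be queried, instead of all of them
    lo = np.floor((np.asarray(x) - r) / cell_size).astype(int)
    hi = np.floor((np.asarray(x) + r) / cell_size).astype(int)
    return [(i, j) for i in range(lo[0], hi[0] + 1) for j in range(lo[1], hi[1] + 1)]

With this layout, each cell's tree could live on a different worker, and query_ball_point would be sent only to the overlapping cells.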
Small style feedback: you might want to avoid the use of partial here and instead do something like the following; this is both somewhat more direct and easier for Dask to serialize.
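The original snippet for this comment isn't shown above; a minimal sketch of the likely suggestion, assuming the earlier revision bound arguments with functools.partial:

# assumed earlier revision: bind leafsize up front with partial
# loader = partial(DaskKDTree._run_load_random, leafsize=self.leafsize)
# trees = [delayed(loader)(chunk_size) for _ in range(parts)]

# suggested: pass the function to delayed directly and bind at call time
trees = [delayed(DaskKDTree._run_load_random)(chunk_size, leafsize=self.leafsize)
         for _ in range(parts)]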