Created
March 10, 2011 00:15
-
-
Save tompaton/863301 to your computer and use it in GitHub Desktop.
Python kd-tree spatial index and nearest neighbour search
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# kd-tree index and nearest neighbour search
# includes doctests, run with: python -m doctest kdtree.py
class KDTree(object):
    """
    kd-tree spatial index and nearest neighbour search
    http://en.wikipedia.org/wiki/Kd-tree

    Each node stores one point (``location``), the coordinate index it
    splits on (``axis``) and two subtrees.  An empty tree/subtree is a
    KDTree whose ``location`` is None, so ``child_left`` and
    ``child_right`` of a non-empty node are always KDTree instances.
    """

    def __init__(self, point_list, _depth=0):
        """
        Initialize kd-tree index with points.

        point_list -- iterable of equal-length coordinate sequences
                      (e.g. tuples); it is copied, never modified.
        _depth     -- internal: recursion depth, used to pick the split axis.

        >>> KDTree([])
        None
        >>> KDTree([(1,1)])
        (0, (1, 1), None, None)
        >>> KDTree([(1,1),(2,2)])
        (0, (2, 2), (1, (1, 1), None, None), None)
        >>> KDTree([(1,1),(2,2),(3,3)])
        (0, (2, 2), (1, (1, 1), None, None), (1, (3, 3), None, None))
        >>> KDTree(p for p in [(1,1),(2,2)])
        (0, (2, 2), (1, (1, 1), None, None), None)
        """
        # Materialize first so any iterable (e.g. a generator) works with
        # the emptiness test and the indexing below.
        points = list(point_list)
        if points:
            # Select axis based on depth so that axis cycles through all
            # dimensions of the points.
            self.axis = _depth % len(points[0])
            # Sort on the split axis and choose the median as pivot so the
            # two subtrees are (nearly) balanced.
            points.sort(key=lambda p: p[self.axis])
            median = len(points) // 2
            # Create node and construct subtrees
            self.location = points[median]
            self.child_left = KDTree(points[:median], _depth + 1)
            self.child_right = KDTree(points[median + 1:], _depth + 1)
        else:
            # Empty node: acts as the "no subtree" marker.
            self.axis = 0
            self.location = None
            self.child_left = None
            self.child_right = None

    def closest_point(self, point, _best=None):
        """
        Efficient recursive search for nearest neighbour to point

        point -- query point with the same dimensionality as the indexed
                 points.
        _best -- internal: best candidate found so far by the caller.

        Returns the indexed point closest to ``point`` as measured by the
        module-level ``distance`` function (on ties, whichever candidate is
        found first), or ``_best`` if this subtree is empty -- so None for
        an empty tree.

        >>> t = KDTree([(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)])
        >>> t
        (0, (7, 2), (1, (5, 4), (0, (2, 3), None, None), (0, (4, 7), None, None)), (1, (9, 6), (0, (8, 1), None, None), None))
        >>> t.closest_point( (7,2) )
        (7, 2)
        >>> t.closest_point( (8,1) )
        (8, 1)
        >>> t.closest_point( (1,1) )
        (2, 3)
        >>> t.closest_point( (5,5) )
        (5, 4)
        >>> KDTree([]).closest_point( (1,1) ) is None
        True
        """
        if self.location is None:
            return _best
        # consider the current node
        if _best is None or distance(self.location, point) < distance(_best, point):
            _best = self.location
        # search the near branch (the side of the split the point lies on)
        _best = self._child_near(point).closest_point(point, _best)
        # search the away branch only if the splitting plane is closer than
        # the best match so far; otherwise nothing there can beat it
        if self._distance_axis(point) < distance(_best, point):
            _best = self._child_away(point).closest_point(point, _best)
        return _best

    # internal methods

    def __repr__(self):
        """
        Simple representation for doctests:
        "(axis, location, left, right)" for a node, "None" for an empty tree.
        """
        if self.location is None:
            return "None"
        return "(%d, %r, %r, %r)" % (
            self.axis, self.location, self.child_left, self.child_right)

    def _distance_axis(self, point):
        """
        Squared distance from point to this node's splitting hyperplane,
        i.e. the squared difference of their coordinates on ``self.axis``.
        Valid for points of any dimension.

        >>> KDTree([(1,1)])._distance_axis((2,3))
        1
        >>> KDTree([(1,1),(2,2)]).child_left._distance_axis((2,3))
        4
        >>> KDTree([(1,1,1)], 2)._distance_axis((0,0,4))
        9
        """
        return (point[self.axis] - self.location[self.axis]) ** 2

    def _child_near(self, point):
        """
        The child on the same side of this node's split as the point
        (points with an equal coordinate on the axis go right).
        """
        if point[self.axis] < self.location[self.axis]:
            return self.child_left
        else:
            return self.child_right

    def _child_away(self, point):
        """
        The child on the opposite side of this node's split from the point.
        """
        if self._child_near(point) is self.child_left:
            return self.child_right
        else:
            return self.child_left
# helper function
def distance(a, b):
    """
    Squared Euclidean distance between points a & b.

    Works for points of any dimension.  Coordinates are paired with zip,
    so both points are expected to have the same length; if they do not,
    the extra trailing coordinates of the longer one are ignored.

    >>> distance((1, 1), (4, 5))
    25
    >>> distance((0, 0, 0), (1, 2, 2))
    9
    """
    return sum((x - y) ** 2 for x, y in zip(a, b))
Hi there,
Thanks for the code, it really helped me. One thing, though: the closest_point function always returns the same point that was passed to it. Any clues why that is happening?
Thanks for the code,
it's really helpful.
To go beyond 2 dimensions, simply change the distance calculation to:
def distance(a, b):
    """
    Return the squared Euclidean distance between points a & b,
    for any number of dimensions (coordinates are paired position by
    position; pairing stops at the shorter point).
    """
    def squared_gap(p, q):
        return (p - q) ** 2
    return sum(map(squared_gap, a, b))
It's really helpful, thanks very much.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The distance function works only if the points in the kd-tree are 2-dimensional. However, the rest of the code does not assume that the points in point_list are 2-dimensional, so simply changing the global distance function fixes the issue.