Last active
December 19, 2015 23:59
-
-
Save dergraf/6038721 to your computer and use it in GitHub Desktop.
Incomplete k-d tree implementation in Erlang.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-module(kdtree). | |
-export([new/1, search/2, distance/2, test/0, test/3]). | |
%% @doc Build a k-d tree from a list of points.
%%
%% Every point is a tuple. The arity of the first point is taken as the
%% number of dimensions K for the whole tree. An empty list yields the empty
%% tree `[]'; otherwise the result is whatever kdtree/3 returns for the list
%% at depth 0, i.e. a `{Point, Left, Right}' node.
-spec new([tuple()]) -> [] | {tuple(), term(), term()}.
new([]) ->
    [];
new([FirstPoint | _] = PointList) ->
    %% tuple_size/1 instead of the generic size/1: points are tuples, and the
    %% type-specific BIF gives the compiler and Dialyzer more information.
    K = tuple_size(FirstPoint),
    kdtree(K, PointList, 0).
%% Recursively build a k-d (sub)tree.
%%
%%   K         - number of dimensions of every point
%%   PointList - points (tuples) to place in this subtree
%%   Depth     - depth of the node being built; selects the splitting axis
%%
%% Returns `[]' for an empty list, otherwise `{Pivot, Left, Right}' where
%% Pivot is the median of the points ordered on the current axis, Left is the
%% subtree of the points ordered before it and Right the subtree of the
%% points ordered after it.
kdtree(_K, [], _Depth) ->
    [];
kdtree(K, PointList, Depth) ->
    %% The splitting axis cycles through 1..K as the depth grows.
    Axis = (Depth rem K) + 1,
    ByAxis = fun(A, B) -> element(Axis, A) =< element(Axis, B) end,
    Sorted = lists:sort(ByAxis, PointList),
    %% The pivot sits at 1-based position (Length div 2) + 1, so exactly
    %% (Length div 2) points precede it in the sorted list.
    Before = length(PointList) div 2,
    {Lower, [Pivot | Upper]} = lists:split(Before, Sorted),
    Next = Depth + 1,
    {Pivot, kdtree(K, Lower, Next), kdtree(K, Upper, Next)}.
%% @doc Nearest-neighbour lookup.
%%
%% Returns the point stored in Tree with the smallest squared distance
%% (see distance/2) to Point, or `undefined' when Tree is the empty tree `[]'.
%% Point must be a tuple; its arity is used as the number of dimensions.
-spec search(term(), tuple()) -> tuple() | undefined.
search(Tree, Point) when is_tuple(Point) ->
    %% tuple_size/1 is the type-specific replacement for the generic size/1.
    search(tuple_size(Point), Tree, Point, 0, undefined).
%% Recursive nearest-neighbour descent.
%%
%%   K     - number of dimensions
%%   Node  - `[]' or `{Location, Left, Right}'
%%   Point - the query point
%%   Depth - depth of Node; selects the splitting axis
%%   Best  - closest point found so far, or `undefined' if none yet
%%
%% Returns the closest point found once this subtree has been explored.
search(_K, [], _Point, _Depth, Best) ->
    Best;
search(K, {Location, _, _} = Node, Point, Depth, undefined) ->
    %% No candidate yet: seed the search with this node's own point.
    search(K, Node, Point, Depth, Location);
search(K, {Location, Left, Right}, Point, Depth, Best) ->
    Candidate =
        case distance(Point, Location) < distance(Point, Best) of
            true -> Location;
            false -> Best
        end,
    Axis = (Depth rem K) + 1,
    %% Descend first into the side of the split that contains the query point.
    {Near, Far} =
        case element(Axis, Point) < element(Axis, Location) of
            true -> {Left, Right};
            false -> {Right, Left}
        end,
    NearBest = search(K, Near, Point, Depth + 1, Candidate),
    %% The far side is only visited when the squared distance to the split
    %% along this axis is smaller than the squared distance to the best
    %% candidate so far.
    case distance_axis(Axis, Location, Point) < distance(NearBest, Point) of
        true -> search(K, Far, Point, Depth + 1, NearBest);
        false -> NearBest
    end.
%% @doc Squared Euclidean distance between the tuples A and B.
%%
%% The number of coordinates is taken from A. The result is a float, except
%% for zero-arity tuples where the integer 0 accumulator is returned as is.
-spec distance(tuple(), tuple()) -> number().
distance(A, B) ->
    %% tuple_size/1 instead of the generic size/1: the points are tuples.
    distance(A, B, tuple_size(A), 0).

%% Accumulator loop for distance/2: adds the squared coordinate differences
%% for dimensions Dim, Dim-1, ..., 1 to Distance.
distance(_A, _B, 0, Distance) ->
    Distance;
distance(A, B, Dim, Distance) ->
    Delta = element(Dim, A) - element(Dim, B),
    distance(A, B, Dim - 1, Distance + math:pow(Delta, 2)).
%% Squared distance between Point and Location measured along coordinate Dim
%% only, i.e. the squared distance from Point to the splitting plane that
%% passes through Location perpendicular to axis Dim.
%%
%% This equals the squared distance between Point and a copy of Point whose
%% Dim-th coordinate is replaced by Location's, but is computed directly so
%% no temporary tuple is built and the other coordinates are not visited.
%% Always returns a float.
distance_axis(Dim, Location, Point) ->
    math:pow(element(Dim, Location) - element(Dim, Point), 2).
%% @doc Run test/3 with 100 iterations for every combination of 1 to 8
%% dimensions and 10, 100, 1000 and 10000 points. Returns the list of the
%% individual test/3 results.
test() ->
    Iters = 100,
    Sizes = [10, 100, 1000, 10000],
    [test(Dim, Size, Iters) || Dim <- lists:seq(1, 8), Size <- Sizes].
%% @doc Benchmark and cross-check one configuration.
%%
%% Builds a tree from N random Dim-dimensional points with integer
%% coordinates in 1..1000, then runs Iters random queries. Each query is
%% answered by a brute-force scan and by search/2; the run crashes with
%% badmatch if the two answers are not at the same distance from the query.
%% Prints the tree build time and the average time per query of both
%% strategies (as measured by timer:tc) and returns the result of io:format/2.
test(Dim, N, Iters) ->
    %% rand replaces the deprecated random module.
    RandomPoint = fun() ->
        list_to_tuple([rand:uniform(1000) || _ <- lists:seq(1, Dim)])
    end,
    PointList = [RandomPoint() || _ <- lists:seq(1, N)],
    {TimeTreeSetup, Tree} = timer:tc(fun new/1, [PointList]),
    %% Reference answer: one pass over the points in their given order.
    %% The list is not sorted first; sorting on every query would add
    %% O(N log N) work to the timed baseline without changing the minimum
    %% distance that is checked below.
    LinearSearch = fun(P) ->
        lists:foldl(
            fun(Point, CurrentClosest) ->
                case distance(P, Point) < distance(P, CurrentClosest) of
                    true -> Point;
                    false -> CurrentClosest
                end
            end,
            hd(PointList),
            PointList
        )
    end,
    Results = lists:map(
        fun(_) ->
            RandomSample = RandomPoint(),
            {TimeLinear, P1} = timer:tc(LinearSearch, [RandomSample]),
            {TimeKDTree, P2} = timer:tc(fun search/2, [Tree, RandomSample]),
            D1 = distance(P1, RandomSample),
            %% badmatch here means the two neighbours are not equally close
            D1 = distance(P2, RandomSample),
            {TimeLinear, TimeKDTree}
        end,
        lists:seq(1, Iters)
    ),
    {LinearTimes, KDTreeTimes} = lists:unzip(Results),
    AvgTimeLinear = lists:sum(LinearTimes) / Iters,
    AvgTimeKDTree = lists:sum(KDTreeTimes) / Iters,
    io:format("~p-dimensional, ~p items, buildtime ~pus, NNS avg time linear ~pus, NNS avg time kdtree ~pus~n", [Dim, N, TimeTreeSetup, AvgTimeLinear, AvgTimeKDTree]).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment