Created
July 30, 2017 00:25
-
-
Save zhenghaoz/7f2d7b9303a3ceca14aab112cd16a46e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Benchmark configuration for the randomized self-check at the end of the
% file.  (Being plain statements, these lines also keep the file from
% starting with a function definition.)
test  = 100;   % number of random trials (loop count in the harness)
dimen = 100;   % number of coordinates per point (columns of x)
count = 100;   % number of data points per trial (rows of x)
function ret = distance(a, b)
  % Euclidean distance from each row of A to B.
  %
  %   a   - matrix, one point per row
  %   b   - a single point (row vector) or a matrix the same size as A
  %   ret - column vector; ret(i) is the distance from a(i,:) to b
  %
  % Plain '-' broadcasts a row vector B across the rows of A.  The
  % original Octave-only '.-' operator is not MATLAB syntax and is
  % deprecated in recent Octave releases.
  ret = sqrt(sum((a - b) .^ 2, 2));
end
function ret = find_k_nearest_naive(x, p, k)
  % Brute-force reference search: row indices of the K rows of X closest
  % to the query point P, ordered from nearest to farthest.
  dists = distance(x, p);
  [~, order] = sort(dists);
  ret = order(1:k);
end
function ret = kdtree_build_recur(x, r, d)
  % Recursively build a k-d tree node over the rows of X listed in R.
  %
  %   x   - data matrix, one point per row
  %   r   - row indices of X in this subtree, pre-sorted by column D
  %   d   - splitting dimension for this node
  %   ret - struct with fields
  %           point : the median index r(ceil(end/2))
  %           dimen : D
  %           left  : (only if present) subtree over the indices before it
  %           right : (only if present) subtree over the indices after it
  %
  % Children split on the next dimension, wrapping around to 1.  Each
  % child's index list is therefore re-sorted by that dimension before
  % recursing.  A single index yields a leaf, since neither child range is
  % non-empty.
  n = length(r);
  mid = ceil(n / 2);
  ret = struct('point', r(mid), 'dimen', d);
  next_d = mod(d, size(x, 2)) + 1;

  if mid > 1
    lo = r(1:mid-1);
    [~, order] = sort(x(lo, next_d));
    ret.left = kdtree_build_recur(x, lo(order), next_d);
  end

  if n > mid
    hi = r(mid+1:n);
    [~, order] = sort(x(hi, next_d));
    ret.right = kdtree_build_recur(x, hi(order), next_d);
  end
end
function ret = kdtree_build(x)
  % Build a k-d tree over the rows of X.
  %
  %   x   - data matrix, one point per row (must be non-empty)
  %   ret - struct with fields
  %           data : X itself (kdtree_find reads points back from it)
  %           root : root node produced by kdtree_build_recur
  %
  % The root splits on dimension 1, so the row indices are pre-sorted by
  % the first column before recursing.
  %
  % An empty X used to fail deep inside the recursion with an opaque index
  % error; reject it up front with a clear message instead.
  if isempty(x)
    error('kdtree_build: X must contain at least one point');
  end
  [~, r] = sort(x(:, 1));
  ret = struct('data', x, 'root', kdtree_build_recur(x, r, 1));
end
function ret = kdtree_cand_fartest(x, p, cand)
  % Return the element of CAND whose row of X is farthest from P.
  %
  %   x    - data matrix, one point per row
  %   p    - query point (row vector)
  %   cand - vector of row indices into X (the current candidate set)
  %   ret  - the element of CAND at the largest distance from P; on ties,
  %          the first such element (max returns the first maximum).
  %          An empty CAND yields an empty result.
  if isempty(cand)
    ret = cand;
    return;
  end
  % Only the candidates' distances matter.  Measuring just those rows
  % avoids recomputing the distance to every point of X on each call, and
  % also avoids Octave-only chained indexing of a function result.
  [~, index] = max(distance(x(cand, :), p));
  ret = cand(index);
end
function ret = kdtree_cand_insert(x, p, cand, k, point)
  % Add POINT to the candidate set CAND, keeping at most K entries.
  %
  %   x     - data matrix, one point per row
  %   p     - query point (row vector)
  %   cand  - current candidate row indices (column vector, possibly [])
  %   k     - maximum number of candidates to keep
  %   point - row index to insert
  %   ret   - the updated candidate set
  %
  % While CAND holds fewer than K indices, POINT is appended as a new row.
  % Otherwise POINT replaces the candidate farthest from P unconditionally.
  % kdtree_find_recur checks that POINT is closer before calling this.
  if length(cand) < k
    ret = [cand; point];
    return;
  end
  fartest = kdtree_cand_fartest(x, p, cand);
  % Logical indexing overwrites the matching entry directly; find() is
  % unnecessary.
  cand(cand == fartest) = point;
  ret = cand;
end
function ret = kdtree_find_recur(x, node, p, ret, k)
  % Depth-first k-nearest-neighbour search of the subtree rooted at NODE.
  %
  %   x    - data matrix, one point per row
  %   node - tree node from kdtree_build_recur: fields point (row index
  %          into X), dimen (split dimension), optional left / right
  %   p    - query point (row vector)
  %   ret  - candidate indices gathered so far ([] at the top level)
  %   k    - number of neighbours wanted
  %   ret  - (output) the updated candidate set
  %
  % The child on the query's side of the split is searched first, then
  % the node's own point is considered.  The other child is searched only
  % while fewer than K candidates are held or the current search radius
  % reaches across the splitting coordinate.  Left-subtree points have
  % coordinate <= split and right-subtree points >= split, as built by
  % kdtree_build_recur.
  pivot = node.point;
  sd = node.dimen;
  split = x(pivot, sd);

  % A query strictly below the split visits the left child first; ties go
  % right first.
  go_left_first = split > p(sd);
  if go_left_first
    near_side = 'left';
    far_side = 'right';
  else
    near_side = 'right';
    far_side = 'left';
  end

  if isfield(node, near_side)
    ret = kdtree_find_recur(x, node.(near_side), p, ret, k);
  end

  % Consider the pivot itself if the set is not full yet, or if it beats
  % the current worst candidate.
  worst = kdtree_cand_fartest(x, p, ret);
  if length(ret) < k || distance(x(pivot, :), p) < distance(x(worst, :), p)
    ret = kdtree_cand_insert(x, p, ret, k, pivot);
  end

  % Can any point on the far side be closer than the current worst one?
  worst = kdtree_cand_fartest(x, p, ret);
  radius = distance(x(worst, :), p);
  if go_left_first
    reaches = p(sd) + radius > split;
  else
    reaches = p(sd) - radius <= split;
  end
  if isfield(node, far_side) && (length(ret) < k || reaches)
    ret = kdtree_find_recur(x, node.(far_side), p, ret, k);
  end
end
function ret = kdtree_find(tree, p, k)
  % Row indices (into tree.data) of the K points nearest to P.
  %
  %   tree - struct produced by kdtree_build (fields data and root)
  %   p    - query point (row vector)
  %   k    - number of neighbours to return
  %
  % The indices come back in candidate-set order, not sorted by distance.
  % The search starts from an empty candidate set.
  ret = kdtree_find_recur(tree.data, tree.root, p, [], k);
end
% Randomized self-check: on TEST random point clouds, compare the k-d tree
% search against brute force and report how many trials agree exactly.
correct = 0;
for i = 1:test
  x = rand(count, dimen);
  y = rand(1, dimen);
  k = ceil(rand(1, 1) * count);   % random neighbour count in 1..count
  ret = kdtree_build(x);
  a = sort(kdtree_find(ret, y, k));
  b = sort(find_k_nearest_naive(x, y, k));
  % isequal also compares sizes.  A result with the wrong number of
  % indices therefore counts as a failed trial instead of aborting the
  % whole run with a size-mismatch error from '=='.
  correct += isequal(a, b);
end
printf('correct:%d/%d\n', correct, test);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment