Skip to content

Instantly share code, notes, and snippets.

@zhenghaoz
Created July 30, 2017 00:25
Show Gist options
  • Save zhenghaoz/7f2d7b9303a3ceca14aab112cd16a46e to your computer and use it in GitHub Desktop.
Save zhenghaoz/7f2d7b9303a3ceca14aab112cd16a46e to your computer and use it in GitHub Desktop.
% Self-test configuration, read by the randomized trial loop at the end of
% this script.
test = 100;   % number of randomized trials
dimen = 100;  % coordinates per point
count = 100;  % points per trial
function ret = distance(a, b)
  % DISTANCE  Euclidean distance from each row of A to the row vector B.
  %   ret = distance(a, b) returns a column vector whose i-th entry is the
  %   Euclidean norm of a(i,:) - b.  B is broadcast against every row of A,
  %   so A may hold any number of rows; zero rows yield a 0x1 result.
  %
  %   Uses the plain "-" operator (automatic broadcasting) rather than the
  %   obsolete Octave-only ".-", which is not MATLAB syntax and is not
  %   accepted by newer Octave releases.
  ret = sqrt(sum((a - b).^2, 2));
end
function ret = find_k_nearest_naive(x, p, k)
  % FIND_K_NEAREST_NAIVE  Brute-force k-nearest-neighbour reference.
  %   Measures the distance from query P to every row of X, ranks the rows
  %   by that distance and returns the row indices of the K closest ones,
  %   nearest first.
  dists = distance(x, p);
  [~, order] = sort(dists);
  ret = order(1:k);
end
function ret = kdtree_build_recur(x, r, d)
  % KDTREE_BUILD_RECUR  Recursively build the kd-tree node for rows R of X.
  %   x - data matrix, one point per row
  %   r - row indices of X belonging to this subtree, sorted ascending by
  %       coordinate D
  %   d - coordinate (column of X) this node splits on
  %   Returns a struct with .point (row index of the median entry of R) and
  %   .dimen (= d), plus .left / .right only for non-empty subtrees.
  %   Children split on the next coordinate, cycling through all columns.
  n = length(r);
  mid = ceil(n / 2);
  ret = struct('point', r(mid), 'dimen', d);
  next_d = mod(d, size(x, 2)) + 1;
  % Entries before the median go left and entries after it go right.
  % Each child list is re-sorted on the next coordinate before recursing.
  lo = r(1:mid-1);
  hi = r(mid+1:n);
  if (~isempty(lo))
    [~, perm] = sort(x(lo, next_d));
    ret.left = kdtree_build_recur(x, lo(perm), next_d);
  end
  if (~isempty(hi))
    [~, perm] = sort(x(hi, next_d));
    ret.right = kdtree_build_recur(x, hi(perm), next_d);
  end
end
function ret = kdtree_build(x)
  % KDTREE_BUILD  Build a kd-tree over the rows of X.
  %   Returns a struct with .data (= x, needed later by kdtree_find) and
  %   .root (the top node, which splits on the first coordinate).
  [~, order] = sort(x(:, 1));
  root = kdtree_build_recur(x, order, 1);
  ret = struct('data', x, 'root', root);
end
function ret = kdtree_cand_fartest(x, p, cand)
  % KDTREE_CAND_FARTEST  Candidate that lies farthest from the query point.
  %   x    - data matrix, one point per row
  %   p    - query point (row vector)
  %   cand - row indices of X currently held as candidates
  %   Returns the element of CAND whose row of X is farthest from P (the
  %   first such element on ties), or an empty value when CAND is empty.
  %
  %   Only the candidate rows are measured, so a call costs O(numel(cand))
  %   instead of O(size(x,1)).  This helper runs several times per visited
  %   tree node, so measuring all of X would make each query quadratic in
  %   the number of points.  This form also avoids Octave-only chained call
  %   indexing.
  [~, idx] = max(distance(x(cand, :), p));
  ret = cand(idx);
end
function ret = kdtree_cand_insert(x, p, cand, k, point)
  % KDTREE_CAND_INSERT  Add POINT to a candidate list holding at most K rows.
  %   x     - data matrix, one point per row
  %   p     - query point (row vector)
  %   cand  - row indices of X kept as candidates (column vector)
  %   k     - maximum number of candidates
  %   point - row index of X to add
  %   While fewer than K candidates are held, POINT is appended.  Otherwise
  %   it overwrites the candidate farthest from P.  Callers are expected to
  %   insert only when POINT is closer than that candidate; kdtree_find_recur
  %   checks this before calling.
  ret = cand;
  if (length(ret) >= k)
    worst = kdtree_cand_fartest(x, p, ret);
    ret(ret == worst) = point;
  else
    ret = [ret; point];
  end
end
function ret = kdtree_find_recur(x, node, p, ret, k)
  % KDTREE_FIND_RECUR  Depth-first k-nearest-neighbour search below NODE.
  %   x    - data matrix, one point per row
  %   node - kd-tree node: .point (row index into x), .dimen (split
  %          coordinate), optional .left / .right children
  %   p    - query point (row vector)
  %   ret  - candidate row indices found so far (at most k)
  %   k    - number of neighbours wanted
  %   Returns the updated candidate list.
  pivot = node.point;
  split_dim = node.dimen;
  split = x(pivot, split_dim);
  descend_left = split > p(split_dim);
  % Visit the child on the query's side of the split first.
  if (descend_left)
    near_side = 'left';
    far_side = 'right';
  else
    near_side = 'right';
    far_side = 'left';
  end
  if (isfield(node, near_side))
    ret = kdtree_find_recur(x, node.(near_side), p, ret, k);
  end
  % Take the pivot itself if the list is not full or it beats the worst.
  worst = kdtree_cand_fartest(x, p, ret);
  if (length(ret) < k || distance(x(pivot,:), p) < distance(x(worst,:), p))
    ret = kdtree_cand_insert(x, p, ret, k, pivot);
  end
  % The far child can only contain closer points if the ball around P
  % through the current worst candidate crosses the splitting plane.  The
  % right side uses a strict test and the left side a non-strict one.
  worst = kdtree_cand_fartest(x, p, ret);
  radius = distance(x(worst,:), p);
  if (descend_left)
    crosses = p(split_dim) + radius > split;
  else
    crosses = p(split_dim) - radius <= split;
  end
  if (isfield(node, far_side) && (length(ret) < k || crosses))
    ret = kdtree_find_recur(x, node.(far_side), p, ret, k);
  end
end
function ret = kdtree_find(tree, p, k)
  % KDTREE_FIND  Row indices of the K points in TREE nearest to query P.
  %   tree - struct produced by kdtree_build (.data and .root)
  %   p    - query point: a row or column vector with one entry per column
  %          of tree.data
  %   k    - number of neighbours wanted
  %   Returns a column vector of row indices into tree.data.  The order is
  %   the search order, not sorted by distance.  An empty 0x1 column is
  %   returned when K < 1.  When K exceeds the number of points, every
  %   point is returned.
  % distance() broadcasts P against rows of the data, so P must be a row
  % vector; a column would silently broadcast into a DxD result.
  p = reshape(p, 1, []);
  if (k < 1)
    % The recursion cannot handle an empty, zero-capacity candidate list.
    ret = zeros(0, 1);
    return;
  end
  ret = kdtree_find_recur(tree.data, tree.root, p, [], k);
end
% Randomized self-test: on TEST random problems of COUNT points in DIMEN
% dimensions, with K drawn from 1..COUNT, check that the kd-tree search
% returns the same index set as brute force.
correct = 0;
for trial = 1:test
  x = rand(count, dimen);
  y = rand(1, dimen);
  k = ceil(rand() * count);
  tree = kdtree_build(x);
  a = sort(kdtree_find(tree, y, k));
  b = sort(find_k_nearest_naive(x, y, k));
  % Compare as sorted index vectors.  isequal reports a length mismatch as
  % a failure instead of erroring (or broadcasting) the way a == b would.
  correct = correct + isequal(a, b);
end
fprintf('correct:%d/%d\n', correct, test);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment