Created November 8, 2021 14:22
-
-
Save arrieta/54e945e99735ca6fb33216678134ef2b to your computer and use it in GitHub Desktop.
Trivial k-d tree implementation in C++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Educational implementation of a k-d tree. | |
#include <algorithm>
#include <array>
#include <iostream>
#include <iterator>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
struct KDTree;

/// Owning pointer to a (sub)tree; nullptr denotes an empty subtree.
using Node = std::unique_ptr<KDTree>;
/// A point in 3-D integer space.
using Point = std::array<int, 3>;
using Points = std::vector<Point>;

/// One node of a 3-D k-d tree: a splitting point plus two owned subtrees.
/// Children are released automatically when their parent is destroyed.
struct KDTree {
  Point p = {};
  std::unique_ptr<KDTree> lc = {};  // points at or below p on this node's axis
  std::unique_ptr<KDTree> rc = {};  // points at or above p on this node's axis
};

/// Builds a k-d tree from `ps`, splitting on axis `level % 3` at each depth.
///
/// The median point along the current axis becomes the node. The points on
/// its low side form the left subtree, and the points on its high side form
/// the right subtree.
///
/// @param ps     Points to insert (taken by value; reordered internally).
/// @param level  Depth of the subtree root; selects the splitting axis.
/// @return Owning pointer to the root, or nullptr when `ps` is empty.
std::unique_ptr<KDTree> build(Points ps, unsigned int level = 0u) {
  if (ps.empty()) {
    return nullptr;
  }
  const auto k = level % ps.front().size();
  // Take points by reference: comparing by value copies an array per call.
  const auto by_axis = [k](const Point& a, const Point& b) { return a[k] < b[k]; };
  const auto mid =
      std::next(ps.begin(), static_cast<Points::difference_type>(ps.size() / 2u));
  // Only the median matters. nth_element places it at `mid` and partitions
  // the rest around it in O(n) expected time, instead of a full O(n log n)
  // sort at every level.
  std::nth_element(ps.begin(), mid, ps.end(), by_axis);

  auto node = std::make_unique<KDTree>();
  node->p = *mid;
  node->lc = build(Points(ps.begin(), mid), level + 1u);
  node->rc = build(Points(std::next(mid), ps.end()), level + 1u);
  return node;
}
auto show(const KDTree* node, int level = 0) { | |
if (node == nullptr) { | |
return; | |
} | |
std::cout << std::string(4 * level, ' ') << "L" << level << "[" << node->p[0] | |
<< ", " << node->p[1] << ", " << node->p[2] << "]\n"; | |
show(node->lc.get(), level + 1); | |
show(node->rc.get(), level + 1); | |
} | |
/// Builds a k-d tree from N pseudo-random 3-D points and prints it.
///
/// Usage: `<program> [N]`; N defaults to 100. Coordinates are drawn uniformly
/// from [-100, 100] with a fixed seed, so repeated runs print the same tree.
///
/// @return 0 on success, 1 if N is not a non-negative integer.
int main(int argc, char* argv[]) {
  int n = 100;
  if (argc > 1) {
    const std::string arg = argv[1];
    bool valid = false;
    try {
      std::size_t used = 0;
      n = std::stoi(arg, &used);
      // Reject trailing garbage ("12x") and negative counts. A negative N
      // would otherwise wrap to a huge size_t in `Points ps(n)` below.
      valid = used == arg.size() && n >= 0;
    } catch (const std::logic_error&) {
      // std::stoi throws std::invalid_argument or std::out_of_range.
    }
    if (!valid) {
      std::cerr << "usage: " << argv[0] << " [N >= 0]\n";
      return 1;
    }
  }

  auto rng = std::default_random_engine(0u);
  auto dis = std::uniform_int_distribution<>(-100, 100);
  Points ps(static_cast<Points::size_type>(n));
  for (auto& q : ps) {
    // Braced-init-list elements are evaluated left to right, so x, y, z
    // consume the generator in a fixed order.
    q = {dis(rng), dis(rng), dis(rng)};
  }
  show(build(ps).get());
  return 0;
}
Python visualization (plots the splitting lines of a 2-D k-d tree over its points):
import sys
import random
import numpy as np
import matplotlib.pyplot as plt
class Node:
    """A k-d tree node: a splitting point plus optional left/right subtrees."""

    def __init__(self, point, lc=None, rc=None):
        self.point = point  # splitting point stored at this node
        self.lc = lc  # left subtree: points at or below the split on this axis
        self.rc = rc  # right subtree: points at or above the split on this axis


def build(points, level=0, dims=2):
    """Recursively build a k-d tree from ``points``.

    Args:
        points: sequence of points indexable by axis (a list of tuples or a
            2-D numpy array). It is not modified.
        level: depth of the subtree root; the split axis is ``level % dims``.
        dims: number of axes to cycle through. The default of 2 matches the
            2-D plotting code.

    Returns:
        The root ``Node``, or None when ``points`` is empty.
    """
    # len(), not truthiness: the truth value of a numpy array is ambiguous.
    if len(points) == 0:
        return None
    axis = level % dims
    ordered = sorted(points, key=lambda p: p[axis])
    m = len(ordered) // 2
    left = build(ordered[:m], level + 1, dims)
    right = build(ordered[m + 1:], level + 1, dims)
    return Node(ordered[m], left, right)
def plot_points(ax, points):
    """Scatter-plot ``points`` on ``ax`` as small black dots.

    ``points`` is indexed as a 2-D numpy array whose first two columns are
    the x and y coordinates. Returns ``ax`` so calls can be chained.
    """
    xs = points[:, 0]
    ys = points[:, 1]
    ax.scatter(xs, ys, s=3, color="k")
    return ax
def plot_node(ax, node, level=0, xmin=0, xmax=1, ymin=0, ymax=1):
    """Draw the splitting lines of the 2-D k-d (sub)tree rooted at ``node``.

    Even levels split on x and draw a vertical line spanning [ymin, ymax].
    Odd levels split on y and draw a horizontal line spanning [xmin, xmax].
    Each node draws its own line first, then recurses into its children.
    Each child is restricted to the half of the current cell it owns.

    Args:
        ax: matplotlib Axes-like object; only ``ax.plot`` is called.
        node: subtree root with ``point``, ``lc`` and ``rc`` attributes, or None.
        level: depth of ``node``; selects the splitting axis.
        xmin, xmax, ymin, ymax: bounds of the cell owned by ``node``.
    """
    if node is None:
        return
    x, y = node.point[0], node.point[1]
    if level % 2 == 0:
        seg_x, seg_y = [x, x], [ymin, ymax]
        low = dict(xmin=xmin, xmax=x, ymin=ymin, ymax=ymax)
        high = dict(xmin=x, xmax=xmax, ymin=ymin, ymax=ymax)
    else:
        seg_x, seg_y = [xmin, xmax], [y, y]
        low = dict(xmin=xmin, xmax=xmax, ymin=ymin, ymax=y)
        high = dict(xmin=xmin, xmax=xmax, ymin=y, ymax=ymax)
    ax.plot(seg_x, seg_y, "k-", linewidth=1)
    plot_node(ax, node.lc, level + 1, **low)
    plot_node(ax, node.rc, level + 1, **high)
def main():
    """Plot the k-d tree partition of ``n`` uniform random points in [0, 1)^2.

    ``n`` comes from the first command-line argument and defaults to 100 when
    no argument is given.
    """
    n = 100 if len(sys.argv) == 1 else int(sys.argv[1])
    points = np.random.rand(n, 2)

    unit = (0, 1)
    ax = plt.subplot(111)
    ax.set_xlim(*unit)
    ax.set_ylim(*unit)

    tree = build(points)
    plot_node(ax, tree)
    plot_points(ax, points)
    plt.show()
# Run the visualization only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Step-by-step construction view: