// Trivial k-d tree implementation in C++.
// Source: GitHub gist arrieta/54e945e99735ca6fb33216678134ef2b
// Created: November 8, 2021 14:22
// Educational implementation of a k-d tree.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
struct KDTree;

// Owning handle to a subtree; a null handle is an empty subtree.
using Node = std::unique_ptr<KDTree>;
// A point with three integer coordinates.
using Point = std::array<int, 3>;
// A collection of points, e.g. the input to build().
using Points = std::vector<Point>;

// A single k-d tree node. It holds the point stored at this node plus
// owning handles to the left and right subtrees; either handle may be null.
struct KDTree {
  Point p{};
  Node lc{};
  Node rc{};
};
std::unique_ptr<KDTree> build(Points ps, unsigned int level = 0u) { | |
auto size = ps.size(); | |
if (size == 0u) { | |
return nullptr; | |
} | |
auto k = level % ps[0].size(); | |
auto pred = [k](auto p, auto q) { return p[k] < q[k]; }; | |
std::sort(ps.begin(), ps.end(), pred); | |
auto beg = ps.begin(); | |
auto mid = std::next(beg, size / 2u); | |
auto end = ps.end(); | |
auto node = std::make_unique<KDTree>(); | |
node->p = *mid; | |
node->lc = build({beg, mid}, level + 1u); | |
node->rc = build({std::next(mid, 1u), end}, level + 1u); | |
return node; | |
} | |
// Pretty-prints the subtree rooted at `node` in pre-order (node, then left
// subtree, then right subtree). Each node goes on its own line as
// "L<level>[c0, c1, ...]", indented by four spaces per level. A null `node`
// prints nothing.
//
// `level` is the depth of `node`; callers normally leave it at 0 (it must be
// non-negative). `os` is the destination stream and defaults to std::cout.
void show(const KDTree* node, int level = 0, std::ostream& os = std::cout) {
  if (node == nullptr) {
    return;
  }
  os << std::string(4 * level, ' ') << 'L' << level << '[';
  // Print every coordinate instead of assuming exactly three. That keeps the
  // output correct (and in bounds) if Point's dimension changes.
  const char* sep = "";
  for (const auto coord : node->p) {
    os << sep << coord;
    sep = ", ";
  }
  os << "]\n";
  show(node->lc.get(), level + 1, os);
  show(node->rc.get(), level + 1, os);
}
int main(int argc, char* argv[]) { | |
auto rng = std::default_random_engine(0u); | |
auto dis = std::uniform_int_distribution<>(-100, 100); | |
const auto N = argc == 1 ? 100 : std::stoi(argv[1]); | |
Points ps(N); | |
for (auto k = 0u; k < ps.size(); ++k) { | |
ps[k] = {dis(rng), dis(rng), dis(rng)}; | |
} | |
show(build(ps).get()); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Python visualization