Created
January 18, 2018 15:33
-
-
Save chengluyu/095a8dff5e43bff1a9941323d158e013 to your computer and use it in GitHub Desktop.
My Naïve Implementation of a K-d Tree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <random>
#include <type_traits>
#include <vector>
/// Stopwatch: records two instants with start()/stop() and reports the gap
/// between them in seconds.
class Timer {
    using Clock = std::chrono::high_resolution_clock;

    Clock::time_point start_;
    Clock::time_point end_;

public:
    /// Remember "now" as the beginning of the measured interval.
    void start() { start_ = Clock::now(); }

    /// Remember "now" as the end of the measured interval.
    void stop() { end_ = Clock::now(); }

    /// Length of the interval [start, stop] expressed in seconds.
    double elapsedSeconds() const {
        using Seconds = std::chrono::duration<double>;
        return std::chrono::duration_cast<Seconds>(end_ - start_).count();
    }
};
/// An interval of values delimited by `min` and `max`.
///
/// The element type only needs `operator<` (plus `+`/`-` for neighbourhood()).
/// Note the differing endpoint conventions: contains() treats the range as open
/// (endpoints excluded), while overlap() treats both ranges as closed (ranges that
/// merely touch at an endpoint count as overlapping).
template <typename TotalOrderComparable>
struct Range {
    TotalOrderComparable min;
    TotalOrderComparable max;

    Range(TotalOrderComparable mi, TotalOrderComparable mx) : min(mi), max(mx) { }

    /// True iff min < value < max (endpoints excluded).
    bool contains(TotalOrderComparable value) const {
        return min < value && value < max;
    }

    /// True iff the closed intervals [min, max] and [rhs.min, rhs.max] intersect.
    /// Const so it can be used on const ranges; written with `<` only, so no
    /// `operator>` is required of the element type.
    bool overlap(const Range &rhs) const {
        return !(max < rhs.min || rhs.max < min);
    }

    /// Builds the range [center - radius, center + radius].
    static Range neighbourhood(TotalOrderComparable center, TotalOrderComparable radius) {
        return Range{ center - radius, center + radius };
    }
};
/// Returns x multiplied by itself; usable in constant expressions.
constexpr auto square(float x) -> float {
    const float product = x * x;
    return product;
}
/// Absolute tolerance used for approximate floating-point comparisons.
/// `inline constexpr`: a read-only compile-time constant with a single definition,
/// rather than a mutable global that any code could overwrite.
template <typename RealType>
inline constexpr RealType epsilon = static_cast<RealType>(1e-7);
struct Vector2F { | |
float x, y; | |
using Axis = float Vector2F::*; | |
static const Axis axes[3]; | |
Vector2F(float x, float y) : x(x), y(y) { } | |
inline float dot(const Vector2F &rhs) const { | |
return x * rhs.x + y * rhs.y; | |
} | |
inline float distance(const Vector2F &rhs) const { | |
return std::sqrt(square(x - rhs.x) + square(y - rhs.y)); | |
} | |
inline bool operator== (const Vector2F &rhs) const { | |
return (x - rhs.x) < epsilon<float> && | |
(y - rhs.y) < epsilon<float>; | |
} | |
inline bool operator!= (const Vector2F &rhs) const { | |
return !(*this == rhs); | |
} | |
}; | |
std::ostream &operator<< (std::ostream &out, const Vector2F &vec) { | |
return out << '(' << vec.x << ", " << vec.y << ')'; | |
} | |
const Vector2F::Axis Vector2F::axes[3] = { &Vector2F::x, &Vector2F::y }; | |
struct KdTreeNode { | |
Vector2F data; | |
Vector2F::Axis axis; | |
KdTreeNode *left_child; | |
KdTreeNode *right_child; | |
float value; | |
Range<float> range = { 0.0f, 0.0f }; | |
KdTreeNode(const Vector2F &data, | |
Vector2F::Axis axis, | |
Range<float> range, | |
KdTreeNode *lc = nullptr, | |
KdTreeNode *rc = nullptr) | |
: data(data), axis(axis), left_child(lc), right_child(rc), value(data.*axis), range(range) { } | |
~KdTreeNode() { | |
delete left_child; | |
delete right_child; | |
} | |
static void dump(KdTreeNode *node, std::ostream &out, size_t depth = 0) { | |
for (size_t i = 0; i < depth; i++) | |
out << " "; | |
if (node) { | |
out << node->data << '\n'; | |
if (node->left_child || node->right_child) { | |
dump(node->left_child, out, depth + 1); | |
dump(node->right_child, out, depth + 1); | |
} | |
} else { | |
out << '*' << '\n'; | |
} | |
} | |
template <typename RandomAccessIt> | |
static KdTreeNode *build(RandomAccessIt begin, RandomAccessIt end, size_t depth = 0) { | |
if (begin == end) | |
return nullptr; | |
auto middle = begin + std::distance(begin, end) / 2; | |
auto axis = Vector2F::axes[depth % 2]; | |
auto comp = [axis](auto lhs, auto rhs) { | |
return lhs.*axis < rhs.*axis; | |
}; | |
std::nth_element(begin, middle, end, comp); | |
auto result = std::minmax_element(begin, end, comp); | |
return new KdTreeNode(*middle, | |
Vector2F::axes[depth % 2], | |
Range{ (*result.first).*axis, (*result.second).*axis }, | |
build(begin, middle, depth + 1), | |
build(middle + 1, end, depth + 1)); | |
} | |
}; | |
class KdTreeQuery { | |
Vector2F source_point_; | |
Vector2F *nearest_point_; | |
float min_distance_; | |
public: | |
explicit KdTreeQuery(const Vector2F &source) | |
: source_point_(source), | |
nearest_point_(nullptr), | |
min_distance_(std::numeric_limits<float>::max()) { } | |
void query(KdTreeNode *node) { | |
if (node == nullptr) | |
return; | |
auto distance = source_point_.distance(node->data); | |
if (distance < min_distance_) { | |
min_distance_ = distance; | |
nearest_point_ = &node->data; | |
} | |
if (source_point_.*node->axis < node->value) { | |
query(node->left_child); | |
if (source_point_.*node->axis + min_distance_ > node->value) | |
query(node->right_child); | |
} else { | |
query(node->right_child); | |
if (source_point_.*node->axis - min_distance_ < node->value) | |
query(node->left_child); | |
} | |
} | |
inline bool isNull() const { | |
return nearest_point_ == nullptr; | |
} | |
inline Vector2F result() const { | |
return *nearest_point_; | |
} | |
inline float distance() const { | |
return min_distance_; | |
} | |
}; | |
std::vector<Vector2F> generateRandomPoints(size_t n) { | |
std::random_device rd; | |
std::mt19937_64 engine(rd()); | |
std::uniform_real_distribution<float> dist{ 0, +100.0f }; | |
std::vector<Vector2F> points; | |
points.reserve(n); | |
for (size_t i = 0; i < n; i++) | |
points.emplace_back(dist(engine), dist(engine)); | |
return points; | |
} | |
/// Returns one point whose coordinates are drawn from
/// uniform_real_distribution<float>(0, 100).
///
/// The engine is seeded once per thread and reused across calls. Previously every
/// call built a new std::random_device and reseeded a new engine. That can be
/// costly on some implementations, and this function is called once per test
/// iteration.
Vector2F generateRandomPoint() {
    thread_local std::mt19937_64 engine{ std::random_device{}() };
    std::uniform_real_distribution<float> dist{ 0, +100.0f };
    // Braced initialisation evaluates left to right: x is drawn before y.
    return Vector2F{ dist(engine), dist(engine) };
}
int main() { | |
std::vector<Vector2F> points = generateRandomPoints(100); | |
Timer timer; | |
KdTreeNode *root = KdTreeNode::build(points.begin(), points.end()); | |
while (true) { | |
auto source = generateRandomPoint(); | |
KdTreeQuery query{source}; | |
query.query(root); | |
assert(!query.isNull()); | |
auto gt = *std::min_element(points.begin(), points.end(), [source](auto lhs, auto rhs) { | |
return source.distance(lhs) < source.distance(rhs); | |
}); | |
if (query.result() != gt) { | |
std::cout << "Test failed.\n"; | |
std::cout << "Source point is " << source << '\n'; | |
std::cout << "Point find by K-d tree is " << query.result() << " while the truth is " << gt << '\n'; | |
std::copy(points.begin(), points.end(), std::ostream_iterator<Vector2F>(std::cout, ",\n")); | |
std::cout << "The tree:\n"; | |
KdTreeNode::dump(root, std::cout); | |
break; | |
} | |
} | |
delete root; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment