Last active
August 29, 2015 14:04
-
-
Save fsmv/311ba8bb550bd6d36851 to your computer and use it in GitHub Desktop.
A K-D Tree implementation which performs a nearest neighbor search on any point type and returns the Euclidean distance between the points in any numeric type, using a user-defined function to read each point's coordinates.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* A K-D Tree implementation which performs a nearest neighbor search on any point type and returns the
* Euclidean distance between the points in any numeric type, using a user-defined coordinate accessor function.
* | |
* Rather than returning the actual nearest object this class returns the index of the nearest object in | |
* the original vector that was passed in. | |
* | |
* Provided under the MIT License | |
* Copyright (c) 2014 Andrew Kallmeyer <[email protected]> | |
* | |
* Permission is hereby granted, free of charge, to any person obtaining a copy | |
* of this software and associated documentation files (the "Software"), to deal | |
* in the Software without restriction, including without limitation the rights | |
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
* copies of the Software, and to permit persons to whom the Software is | |
* furnished to do so, subject to the following conditions: | |
* | |
* The above copyright notice and this permission notice shall be included in | |
* all copies or substantial portions of the Software. | |
* | |
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | |
* THE SOFTWARE. | |
*/
#pragma once | |
#include <vector> | |
/*
 * A kd-tree node. `index` refers to a point stored by the owning tree;
 * `left`/`right` are child subtrees (nullptr when absent).
 */
struct Node {
    int index;
    Node* left = nullptr;   // default to "no child" so isLeaf() never reads garbage
    Node* right = nullptr;
    explicit Node(int index) : index(index) {}
    // True when the node has no children.
    bool isLeaf() const { return left == nullptr && right == nullptr; }
};
template <typename PointType, typename DistType = double> | |
class KDTree { | |
public: | |
KDTree(DistType (*getDimVal)(const PointType&, int dim), int dimension = 2); | |
~KDTree(void); | |
/* | |
* add all entries in the array to the data structure | |
* | |
* Note: If two items have the same coordinates, then do not add the new item | |
* that has the same coordinates as another item. | |
*/ | |
void build(const std::vector<PointType> &c); | |
/* | |
* Return a pointer to the entry that is closest to the given coordinates. | |
*/ | |
int getNearest(const PointType &point, DistType *dist) const; | |
private: | |
Node* root; | |
std::vector<int> dataIndicies; | |
DistType getDist(const PointType &lhs, const PointType &rhs) const; | |
bool comp(int dim, const PointType &lhs, const PointType &rhs) const; | |
DistType (*getDimVal)(const PointType&, int dim); | |
int dimension; | |
void destroy(Node* root); | |
/* | |
* Recursively builds a kd-tree, n is the length of c | |
*/ | |
Node* build(int start, int n, int depth); | |
/* | |
* Recursively finds the nearest neighbor to a given point (x, y) root is the root of the kd-tree to search and currResult is the current closest (arbitrary initial choice) | |
*/ | |
int getNearest(const PointType &point, Node* node, int best, int depth = 0) const; | |
}; | |
//======= Implementation ======= | |
#include <cmath> | |
#include <algorithm> | |
#include <functional> | |
#include <iterator> | |
template<typename T, typename V> | |
KDTree<T, V>::KDTree(V (*getDimVal)(const T&, int dim), int dimension) : getDimVal(getDimVal), dimension(dimension) {} | |
template<typename T, typename V> | |
KDTree<T, V>::~KDTree(void) { | |
if(root != nullptr) { | |
destroy(root); | |
} | |
} | |
template<typename T, typename V> | |
void KDTree<T, V>::destroy(Node* root) { | |
if(root->left != nullptr){ | |
destroy(root->left); | |
} | |
if(root->right != nullptr){ | |
destroy(root->right); | |
} | |
delete root; | |
} | |
template<typename T, typename V> | |
void KDTree<T, V>::build(const std::vector<T> &c) { | |
data = c; | |
dataIndicies.reserve(data.size()); | |
for(unsigned int i = 0; i < data.size(); ++i) { | |
dataIndicies.push_back(i); | |
} | |
root = build(0, c.size(), 0); | |
} | |
template<typename T, typename V> | |
Node* KDTree<T, V>::build(int start, int n, int depth) { | |
Node* result; | |
if(n == 0) { | |
return nullptr; | |
}else if(n == 1) { | |
result = new Node(dataIndicies[start]); | |
result->left = nullptr; | |
result->right = nullptr; | |
return result; | |
} | |
std::sort(dataIndicies.begin() + start, dataIndicies.begin() + start + n, [&](int lhs, int rhs){ return comp(depth % dimension, data[lhs], data[rhs]); }); | |
const int halfLength = n/2; | |
result = new Node(dataIndicies[start + halfLength]); | |
result->left = build(start, halfLength, depth + 1); | |
result->right = build(start + halfLength + 1, (n % 2 == 0 ? -1 : 0) + halfLength, depth + 1); | |
return result; | |
} | |
template<typename T, typename V> | |
int KDTree<T, V>::getNearest(const T &point, V *dist) const { | |
int result = getNearest(point, root, root->index); | |
*dist = (V) std::sqrt((double) getDist(point, data[result])); | |
return result; | |
} | |
template<typename T, typename V> | |
int KDTree<T, V>::getNearest(const T &point, Node* node, int best, int depth) const { | |
bool leftOf = comp(depth % dimension, point, data[node->index]); | |
Node *childNear = leftOf ? node->left : node->right; | |
Node *childFar = leftOf ? node->right : node->left; | |
if(getDist(point, data[node->index]) < getDist(point, data[best])) { | |
best = node->index; | |
} | |
if(childNear != nullptr) { | |
best = getNearest(point, childNear, best, depth+1); | |
} | |
V axisDist = getDimVal(point, depth % dimension) - getDimVal(data[node->index], depth % dimension); | |
axisDist *= axisDist; //getDist is a squared distance to save time, this needs to be squared also | |
if(axisDist <= getDist(point, data[best])) { | |
if(childFar != nullptr) { | |
best = getNearest(point, childFar, best, depth+1); | |
} | |
} | |
return best; | |
} | |
template<typename T, typename V> | |
V KDTree<T,V>::getDist(const T &lhs, const T &rhs) const { | |
V sum = (V) 0; | |
for(int i = 0; i < dimension; ++i) { | |
V val = getDimVal(lhs, i) - getDimVal(rhs, i); | |
sum += val * val; | |
} | |
return sum; | |
} | |
template<typename T, typename V> | |
bool KDTree<T,V>::comp(int dim, const T &lhs, const T &rhs) const { | |
return getDimVal(lhs, dim) < getDimVal(rhs, dim); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment