Skip to content

Instantly share code, notes, and snippets.

@danielkelshaw
Last active February 12, 2020 16:02
Show Gist options
  • Save danielkelshaw/de639cdcaabb972c10da3982f727a25f to your computer and use it in GitHub Desktop.
Save danielkelshaw/de639cdcaabb972c10da3982f727a25f to your computer and use it in GitHub Desktop.
KdTree Implementation
#include <vector>
#include <cmath>
/// One node of the kd-tree: a stored point, the caller-supplied id, and two child links.
///
/// Nodes are allocated by KdTree::insertHelper and released by the owning KdTree.
struct Node {
    std::vector<float> point;  ///< Coordinates; KdTree reads indices 0, 1 and 2.
    int id;                    ///< Identifier reported back by KdTree::search.
    Node* left;                ///< Points that compared "less than" on this node's split axis.
    Node* right;               ///< Points that compared "greater or equal" on this node's split axis.

    /// Stores a copy of `arr` together with `setID`; both children start empty.
    Node(const std::vector<float>& arr, int setID)
        : point(arr), id(setID), left(nullptr), right(nullptr)
    {}
};

/// Kd-tree over 3-D points supporting insertion and fixed-radius neighbour search.
///
/// The split axis cycles x, y, z with depth (`depth % 3`). Every inserted or
/// queried point must provide at least three coordinates; this is a precondition
/// and is not checked. The tree owns every node reachable from `root` and frees
/// them on destruction; copying a tree makes an independent deep copy.
struct KdTree {
    Node* root;

    /// Creates an empty tree.
    KdTree()
        : root(nullptr)
    {}

    /// Deep-copies `other` so the two trees never share (and double-free) nodes.
    KdTree(const KdTree& other)
        : root(cloneSubtree(other.root))
    {}

    /// Takes over the nodes of `other`, leaving it empty.
    KdTree(KdTree&& other) noexcept
        : root(other.root)
    {
        other.root = nullptr;
    }

    /// Unified copy/move assignment: `other` is already a fresh copy (or a moved-in
    /// tree); swapping roots hands our old nodes to `other`, whose destructor frees them.
    KdTree& operator=(KdTree other) noexcept {
        Node* previous = root;
        root = other.root;
        other.root = previous;
        return *this;
    }

    /// Releases every node owned by the tree.
    ~KdTree() {
        destroySubtree(root);
    }

    /// Inserts `point` with identifier `id` into the subtree rooted at `node`.
    ///
    /// @param node  Root link of the subtree; updated in place when the subtree is empty.
    /// @param level Depth of `node` in the whole tree; selects the split axis.
    /// @param point Coordinates to store (at least three values).
    /// @param id    Identifier returned by search() for this point.
    void insertHelper(Node*& node, unsigned int level, const std::vector<float>& point, int id) {
        // Walk down iteratively: sorted input yields a degenerate, list-like tree,
        // and a loop keeps the stack depth constant regardless of tree height.
        Node** slot = &node;
        while (*slot != nullptr) {
            const unsigned int axis = level % kDims;
            Node* current = *slot;
            slot = (point[axis] < current->point[axis]) ? &current->left : &current->right;
            ++level;
        }
        *slot = new Node(point, id);
    }

    /// Inserts every point of `cloud`, using its index in the cloud as its id.
    /// A null `cloud` is treated as empty.
    void insertCloud(std::vector<std::vector<float>>* cloud) {
        if (cloud == nullptr) {
            return;
        }
        int nextId = 0;
        for (const std::vector<float>& point : *cloud) {
            insertHelper(root, 0, point, nextId);
            ++nextId;
        }
    }

    /// Appends to `ids` the id of every point in the subtree at `node` whose
    /// Euclidean distance to `target` is strictly less than `distanceTol`.
    ///
    /// @param node        Root of the subtree to scan (may be null).
    /// @param depth       Depth of `node` in the whole tree; selects the split axis.
    /// @param ids         Output list; must not be null.
    /// @param target      Query point (at least three values).
    /// @param distanceTol Search radius.
    void searchHelper(Node*& node, unsigned int depth, std::vector<int>* ids,
                      const std::vector<float>& target, float distanceTol) {
        if (node == nullptr) {
            return;
        }
        const unsigned int axis = depth % kDims;
        const std::vector<float>& candidate = node->point;

        // Cheap axis-aligned box test first; only candidates inside the box pay for sqrt.
        bool insideBox = true;
        for (unsigned int k = 0; k < kDims; ++k) {
            if (!(candidate[k] < target[k] + distanceTol && candidate[k] > target[k] - distanceTol)) {
                insideBox = false;
                break;
            }
        }
        if (insideBox) {
            const float dx = candidate[0] - target[0];
            const float dy = candidate[1] - target[1];
            const float dz = candidate[2] - target[2];
            // Keep the distance in floating point: truncating it to an integer would
            // accept points that lie outside a fractional tolerance.
            const float distance = std::sqrt(dx * dx + dy * dy + dz * dz);
            if (distance < distanceTol) {
                ids->push_back(node->id);
            }
        }

        // Descend only into the half-spaces that the search box overlaps.
        if (target[axis] - distanceTol < candidate[axis]) {
            searchHelper(node->left, depth + 1, ids, target, distanceTol);
        }
        if (target[axis] + distanceTol > candidate[axis]) {
            searchHelper(node->right, depth + 1, ids, target, distanceTol);
        }
    }

    /// Returns the ids of all stored points closer than `distanceTol` to `target`,
    /// in pre-order (node, left subtree, right subtree) traversal order.
    std::vector<int> search(const std::vector<float>& target, float distanceTol) {
        std::vector<int> ids;
        searchHelper(root, 0, &ids, target, distanceTol);
        return ids;
    }

private:
    /// Number of coordinates the tree splits on and measures distance over.
    static constexpr unsigned int kDims = 3;

    /// Frees every node of `subtree` without recursion or allocation.
    static void destroySubtree(Node* subtree) noexcept {
        Node* current = subtree;
        while (current != nullptr) {
            if (current->left != nullptr) {
                // Rotate the left child above `current`; repeating this flattens the
                // tree into a right-leaning chain that can be deleted front to back.
                Node* pivot = current->left;
                current->left = pivot->right;
                pivot->right = current;
                current = pivot;
            } else {
                Node* next = current->right;
                delete current;
                current = next;
            }
        }
    }

    /// Returns a newly allocated deep copy of `source` (null for an empty subtree).
    static Node* cloneSubtree(const Node* source) {
        struct Pending {
            const Node* from;  // node to copy
            Node** into;       // link that must receive the copy
        };

        Node* copyRoot = nullptr;
        if (source == nullptr) {
            return copyRoot;
        }
        std::vector<Pending> pending;
        try {
            pending.push_back(Pending{source, &copyRoot});
            while (!pending.empty()) {
                const Pending job = pending.back();
                pending.pop_back();
                Node* fresh = new Node(job.from->point, job.from->id);
                *job.into = fresh;
                if (job.from->left != nullptr) {
                    pending.push_back(Pending{job.from->left, &fresh->left});
                }
                if (job.from->right != nullptr) {
                    pending.push_back(Pending{job.from->right, &fresh->right});
                }
            }
        } catch (...) {
            // Do not leak the partially built copy if an allocation fails.
            destroySubtree(copyRoot);
            throw;
        }
        return copyRoot;
    }
};
#include <iostream>
#include "kdtree.h"
/// Demo driver: builds a cloud of 3000 points (i, i, i), indexes it in a KdTree,
/// and prints the ids of all points within distance 30 of a fixed target.
int main() {
    std::cout << "KdTree Demonstration" << '\n' << '\n';

    // The cloud lives on the stack so it is released automatically; the tree
    // copies each point on insertion and does not keep a pointer to the cloud.
    std::vector<std::vector<float>> inputCloud;
    inputCloud.reserve(3000);
    for (int i = 0; i < 3000; ++i) {
        inputCloud.push_back(std::vector<float>(3, static_cast<float>(i)));
    }

    // instantiate tree
    KdTree tree;
    tree.insertCloud(&inputCloud);

    // search for points
    const std::vector<float> target = {1212.12f, 1214.25f, 1210.52f};
    const std::vector<int> returnedIDs = tree.search(target, 30);

    std::cout << "Indices Found: " << '\n';
    for (const int foundId : returnedIDs) {
        std::cout << foundId << " ";
    }
    std::cout << '\n' << '\n' << "Done." << '\n';
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment