Last active
September 24, 2019 14:29
-
-
Save jamesgregson/20d9dfb4f7f49c33a90bf27324c00e91 to your computer and use it in GitHub Desktop.
Simple KDTree (3D)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef __KD_TREE_HEADER_H | |
#define __KD_TREE_HEADER_H | |
#include <algorithm>
#include <cstddef>
#include <functional>
#include <limits>
#include <list>
#include <queue>
#include <set>
#include <tuple>
#include <vector>
namespace graphics {

/// Axis-aligned box in 3D, stored as its per-axis minimum (lo) and maximum
/// (hi) corners. A point is represented by a box with lo == hi.
struct kdtree_bounds {
    float lo[3];
    float hi[3];

    /// Smallest box enclosing both this box and `in`.
    kdtree_bounds box_union( const kdtree_bounds& in ) const {
        return {
            { std::min(lo[0],in.lo[0]), std::min(lo[1],in.lo[1]), std::min(lo[2],in.lo[2]) },
            { std::max(hi[0],in.hi[0]), std::max(hi[1],in.hi[1]), std::max(hi[2],in.hi[2]) }
        };
    }

    /// Overlap of this box and `in`. If the boxes are disjoint the result is
    /// inverted (lo > hi on at least one axis); test intersects() first when
    /// that matters.
    kdtree_bounds box_intersection( const kdtree_bounds& in ) const {
        return {
            { std::max(lo[0],in.lo[0]), std::max(lo[1],in.lo[1]), std::max(lo[2],in.lo[2]) },
            { std::min(hi[0],in.hi[0]), std::min(hi[1],in.hi[1]), std::min(hi[2],in.hi[2]) }
        };
    }

    /// True when the boxes overlap or touch (closed-interval test per axis).
    bool intersects( const kdtree_bounds& in ) const {
        return hi[0] >= in.lo[0] && lo[0] <= in.hi[0]
            && hi[1] >= in.lo[1] && lo[1] <= in.hi[1]
            && hi[2] >= in.lo[2] && lo[2] <= in.hi[2];
    }
};

/// One node of the tree; nodes live in a flat array inside kdtree.
///  - children == -1 : default-constructed node that was never split (leaf).
///  - children ==  0 : leaf created by a split.
///  - children  >  0 : interior node. Its children are at indices `children`
///                     (lo side of the plane) and `children+1` (hi side), and
///                     `items` is empty.
/// Leaves store the indices of the input items they overlap in `items`.
struct kdtree_node {
    int split_axis = 0;
    float split_coord = 0.0f;
    int children = -1;
    std::vector<int> items;
};

/// Static 3D kd-tree over axis-aligned boxes (points are boxes with lo == hi).
///
/// The constructor builds the tree once; query_closest() then finds the
/// nearest item. Items are identified by their index in the list given to the
/// constructor. An item touching or straddling a split plane is stored in both
/// children.
class kdtree {
public:
    /// Build the tree.
    ///
    /// @param num_items  number of items; indices 0..num_items-1 of `items` are used.
    /// @param items      indexable list; `items[i]` must be convertible to
    ///                   kdtree_bounds and expose `lo[axis]` / `hi[axis]`.
    /// @param max_items  a node holding more than this many items is split.
    /// @param max_levels number of tree levels including the root; nodes on the
    ///                   last level stay leaves regardless of item count.
    template< typename ItemList >
    kdtree( const size_t num_items, const ItemList& items, int max_items=20, int max_levels=25 ){
        // Nodes are appended on demand instead of pre-allocating 2^max_levels
        // of them. Pre-allocation costs over a gigabyte at the default depth of 25
        // and overflows the shift for max_levels >= 31.
        m_nodes.emplace_back();
        if( num_items == 0 ){
            return;   // empty tree: a single empty root leaf
        }

        // The root covers the union of all item boxes and owns every item.
        kdtree_bounds root_bnds = items[0];
        {
            std::vector<int>& root_items = m_nodes[0].items;
            root_items.reserve( num_items );
            for( size_t i=0; i<num_items; ++i ){
                root_items.push_back( static_cast<int>(i) );
                root_bnds = root_bnds.box_union( items[i] );
            }
        }

        // Subdivide breadth first, one level at a time. Each split appends its
        // two children to m_nodes, so node numbering follows FIFO order.
        using pending = std::tuple<int,kdtree_bounds>;
        std::vector<pending> current, upcoming;
        current.emplace_back( 0, root_bnds );
        for( int level=0; level<max_levels-1 && !current.empty(); ++level ){
            for( const pending& entry : current ){
                const int nid = std::get<0>(entry);
                const kdtree_bounds& nbnds = std::get<1>(entry);
                if( m_nodes[nid].items.size() <= static_cast<size_t>(max_items) ){
                    continue;   // small enough: stays a leaf
                }

                const std::tuple<int,float> plane = split_plane( nbnds );
                const int axis = std::get<0>(plane);
                const float split = std::get<1>(plane);

                const int cid = static_cast<int>( m_nodes.size() );
                m_nodes.resize( m_nodes.size()+2 );
                // Take references only after the resize, which may reallocate.
                kdtree_node& node     = m_nodes[nid];
                kdtree_node& lo_child = m_nodes[cid+0];
                kdtree_node& hi_child = m_nodes[cid+1];
                node.split_axis  = axis;
                node.split_coord = split;
                node.children    = cid;
                lo_child.children = 0;
                hi_child.children = 0;

                // Distribute items; anything touching the plane goes to both sides.
                for( const int idx : node.items ){
                    if( items[idx].lo[axis] <= split ){
                        lo_child.items.push_back( idx );
                    }
                    if( items[idx].hi[axis] >= split ){
                        hi_child.items.push_back( idx );
                    }
                }
                // Interior nodes hold no items: release the storage, not just the size.
                std::vector<int>().swap( node.items );

                // Children cover the parent's box clipped at the split plane.
                kdtree_bounds lo_bnds = nbnds;
                lo_bnds.hi[axis] = split;
                kdtree_bounds hi_bnds = nbnds;
                hi_bnds.lo[axis] = split;
                upcoming.emplace_back( cid+0, lo_bnds );
                upcoming.emplace_back( cid+1, hi_bnds );
            }
            current.swap( upcoming );
            upcoming.clear();
        }
    }

    /// Find the item closest to `bnd` by best-first traversal.
    ///
    /// @param dis_fn callable `(const std::vector<int>& ids, const kdtree_bounds& query)`
    ///        returning `std::tuple<float,int>`. The tuple holds the smallest distance
    ///        from `query` to the listed items and that item's index.
    ///        - The distance must be squared Euclidean, or any measure never smaller
    ///          than the squared gap to a split plane. Subtrees are pruned by
    ///          comparing it to those squared gaps.
    ///        - It is taken by forwarding reference, so lvalues and temporaries both work.
    /// @param bnd    query box (a point when lo == hi).
    /// @return (distance, index) of the closest item.
    ///         - If the root was never split, dis_fn's result for the root's items
    ///           is returned as is.
    ///         - Otherwise the result is (+infinity, -1) if no non-empty leaf reported
    ///           a finite distance.
    template< typename distance_func >
    std::tuple<float,int> query_closest( distance_func&& dis_fn, const kdtree_bounds& bnd ) const {
        // Degenerate case: the root is a leaf.
        if( m_nodes[0].children <= 0 ){
            return dis_fn( m_nodes[0].items, bnd );
        }

        // Min-heap of (lower bound on squared distance, node id).
        using entry = std::tuple<float,int>;
        std::priority_queue<entry, std::vector<entry>, std::greater<entry>> frontier;
        frontier.emplace( 0.0f, 0 );

        float mind = std::numeric_limits<float>::infinity();
        int best = -1;
        while( !frontier.empty() ){
            const entry top = frontier.top();
            frontier.pop();
            const float bound = std::get<0>(top);
            const int nid = std::get<1>(top);

            // Every remaining node's bound is >= this one, so none can beat mind.
            if( bound >= mind ){
                break;
            }

            const kdtree_node& node = m_nodes[nid];
            if( node.children <= 0 ){
                if( node.items.empty() ){
                    continue;   // nothing to measure in an empty leaf
                }
                float leaf_dis;
                int leaf_best;
                std::tie(leaf_dis,leaf_best) = dis_fn( node.items, bnd );
                if( leaf_dis < mind ){
                    mind = leaf_dis;
                    best = leaf_best;
                }
            } else {
                // A child's bound is the larger of the parent's bound and the
                // squared gap between the query box and the child's half-space.
                const int axis = node.split_axis;
                const float split = node.split_coord;
                const float gap0 = std::max( 0.0f, bnd.lo[axis]-split );
                const float gap1 = std::max( 0.0f, split-bnd.hi[axis] );
                frontier.emplace( std::max(bound, gap0*gap0), node.children+0 );
                frontier.emplace( std::max(bound, gap1*gap1), node.children+1 );
            }
        }
        return std::make_tuple(mind,best);
    }

private:
    /// Split heuristic: the midpoint of the longest box edge (ties pick the
    /// lower axis). Simple and cheap; the original author notes it performs
    /// poorly for ray tracing.
    static std::tuple<int,float> split_plane( const kdtree_bounds& b ){
        const float extent[] = { b.hi[0]-b.lo[0], b.hi[1]-b.lo[1], b.hi[2]-b.lo[2] };
        int axis = 0;
        if( extent[1] > extent[axis] ) axis = 1;
        if( extent[2] > extent[axis] ) axis = 2;
        return std::make_tuple( axis, b.lo[axis]+extent[axis]*0.5f );
    }

    std::vector<kdtree_node> m_nodes;
};

}  // namespace graphics
#endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "kd_tree.h" | |
#include <ctime> | |
#include <cstdlib> | |
#include <iostream> | |
#include <stdexcept> | |
#include <sys/time.h> | |
/// Wall-clock time in seconds since the Unix epoch, with microsecond
/// resolution, as reported by POSIX gettimeofday().
double curr_time(){
    timeval now;
    gettimeofday( &now, nullptr );
    const double whole_seconds = static_cast<double>( now.tv_sec );
    const double fraction      = 1e-6*static_cast<double>( now.tv_usec );
    return whole_seconds + fraction;
}
/// Degenerate box (lo == hi) representing the single point (x, y, z).
graphics::kdtree_bounds make_bounds( float x, float y, float z ){
    const float xyz[3] = { x, y, z };
    graphics::kdtree_bounds box;
    for( int axis = 0; axis < 3; ++axis ){
        box.lo[axis] = xyz[axis];
        box.hi[axis] = xyz[axis];
    }
    return box;
}
/// Debug printer for a box. Writes only the lo corner, as "[x, y, z]",
/// which fully describes the point boxes built by make_bounds().
std::ostream& operator<<( std::ostream& os, const graphics::kdtree_bounds& bnd ){
    os << "[";
    for( int axis = 0; axis < 3; ++axis ){
        if( axis > 0 ){
            os << ", ";
        }
        os << bnd.lo[axis];
    }
    return os << "]";
}
class point_set { | |
public: | |
point_set( const std::vector<float>& pnts ) : m_pnts(pnts) { | |
} | |
size_t size() const { | |
return m_pnts.size()/3; | |
} | |
graphics::kdtree_bounds operator[]( const size_t &idx ) const { | |
return make_bounds( m_pnts[idx*3+0], m_pnts[idx*3+1], m_pnts[idx*3+2] ); | |
} | |
private: | |
const std::vector<float> &m_pnts; | |
}; | |
int main( int argc, char **argv ){ | |
std::vector<int> items; | |
std::vector<float> pnts; | |
for( auto i=0; i<100000; i++ ){ | |
pnts.push_back( drand48() ); | |
pnts.push_back( drand48() ); | |
pnts.push_back( drand48() ); | |
items.push_back(i); | |
} | |
auto ps = point_set(pnts); | |
auto min_dis_func = [&]( const std::vector<int>& items, const graphics::kdtree_bounds& bnd ){ | |
int best; | |
float dis, min_dis = 1e10f; | |
for( auto idx : items ){ | |
float delta[] = { pnts[idx*3+0]-bnd.lo[0], pnts[idx*3+1]-bnd.lo[1], pnts[idx*3+2]-bnd.lo[2] }; | |
dis = delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2]; | |
if( dis < min_dis ){ | |
min_dis = dis; | |
best = idx; | |
} | |
} | |
return std::make_tuple( min_dis, best ); | |
}; | |
graphics::kdtree kdtree( pnts.size()/3, point_set(pnts) ); | |
std::cout << "done building tree..." << std::endl; | |
srand(time(NULL)); | |
srand48(time(NULL)); | |
int best_kd, best_bf; | |
float min_dis_kd, min_dis_bf, total_dis; | |
for( auto i=0; i<1000; ++i ){ | |
auto bnd = make_bounds(drand48(),drand48(),drand48()); | |
std::tie(min_dis_kd,best_kd) = kdtree.query_closest( min_dis_func, bnd ); | |
std::tie(min_dis_bf,best_bf) = min_dis_func( items, bnd ); | |
if( best_kd != best_bf ){ | |
std::cout << best_kd << " " << best_bf << std::endl; | |
std::cout << min_dis_kd << " " << min_dis_bf << std::endl; | |
} | |
} | |
int N = 100000; | |
std::vector<float> x,y,z; | |
for( auto i=0; i<N; ++i ){ | |
x.push_back(drand48()); | |
y.push_back(drand48()); | |
z.push_back(drand48()); | |
} | |
total_dis = 0.0f; | |
double t = curr_time(); | |
for( auto i=0; i<N; ++i ){ | |
auto bnd = make_bounds(x[i],y[i],z[i]); | |
std::tie(min_dis_kd,best_kd) = kdtree.query_closest( min_dis_func, bnd ); | |
total_dis += min_dis_kd; | |
} | |
std::cout << "Total time: " << (curr_time()-t)*1e6/double(N) << "us/item" << std::endl; | |
std::cout << "Total distance: " << total_dis << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment