Skip to content

Instantly share code, notes, and snippets.

@lnicola
Created June 24, 2014 14:48
Show Gist options
  • Save lnicola/44d9b27751bdb584f8de to your computer and use it in GitHub Desktop.
Save lnicola/44d9b27751bdb584f8de to your computer and use it in GitHub Desktop.
#include <algorithm>
#include <array>
#include <chrono>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
#include <mmintrin.h>
using namespace std;
using namespace std::chrono;
/// Requests a read prefetch of the memory at `addr`.
///
/// GCC and Clang expose `__builtin_prefetch` without any header; arguments
/// (0, 3) mean "read access, keep in all cache levels". Other compilers keep
/// the `_mm_prefetch(..., _MM_HINT_T0)` call this file originally made, which
/// relies on the intrinsic header included at the top of the file.
static inline void prefetch_for_read(const char *addr)
{
#if defined(__GNUC__) || defined(__clang__)
    __builtin_prefetch(addr, 0, 3);
#else
    _mm_prefetch(addr, _MM_HINT_T0);
#endif
}

/// `std::upper_bound` preceded by a prefetch sweep over the searched range.
///
/// Walks the bytes of [first, last) in 64-byte steps (presumably one cache
/// line per step -- confirm for the target CPU) and prefetches each step
/// before running the binary search.
///
/// @param first  pointer to the first element of a sorted range; must be a raw
///               pointer (possibly to const), because the range is
///               reinterpreted as bytes
/// @param last   pointer one past the final element of the range
/// @param val    value to search for
/// @return       pointer to the first element greater than `val`, or `last`
///               if there is none (the result of `std::upper_bound`)
template<class FwdIt, class Ty>
static inline FwdIt my_upper_bound(FwdIt first, FwdIt last, const Ty val)
{
    // const char * (not char *) so that pointers-to-const are accepted too.
    const char *const range_end = reinterpret_cast<const char *>(last);
    for (const char *p = reinterpret_cast<const char *>(first); p < range_end; p += 64)
        prefetch_for_read(p);

    return std::upper_bound(first, last, val);
}
/// Scope timer.
///
/// Records the current time when constructed. When destroyed, writes the time
/// elapsed since construction to std::cout as "<whole milliseconds> ms",
/// followed by a newline and a flush.
class stopwatch
{
    // steady_clock never goes backwards, so the measured interval cannot be
    // distorted by wall-clock adjustments.
    std::chrono::steady_clock::time_point start;

public:
    stopwatch()
        : start(std::chrono::steady_clock::now())
    {
    }

    // One instance reports exactly once; a copy would print a second report.
    stopwatch(const stopwatch &) = delete;
    stopwatch &operator=(const stopwatch &) = delete;

    ~stopwatch()
    {
        const auto elapsed = std::chrono::steady_clock::now() - start;
        std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count() << " ms" << std::endl;
    }
};
/// B-tree node holding up to 2*t - 1 keys of type T.
///
/// A plain b_node is a leaf; b_node_internal derives from it, adds the child
/// pointers and clears `leaf`. Code in this file uses `leaf` to decide whether
/// a b_node pointer may be downcast to b_node_internal.
///
/// @tparam T  key type
/// @tparam t  degree parameter; the key capacity is 2*t - 1
template<typename T, short t = 2>
struct b_node
{
    short num_keys;     // number of entries of `keys` currently in use
    bool leaf;          // true unless a derived constructor resets it
    T keys[2 * t - 1];  // key storage; only the first `num_keys` are meaningful

    /// Creates a leaf node that claims `num_keys` keys. The key slots
    /// themselves are not filled in here; the caller is expected to write them.
    explicit b_node(short num_keys) : num_keys(num_keys), leaf(true)
    {
    }

    /// @return true when every key slot is in use (num_keys == 2*t - 1).
    bool is_full() const
    {
        return num_keys == 2 * t - 1;
    }
};
/// Interior B-tree node: a b_node that also stores up to 2*t child pointers.
///
/// The constructor clears the inherited `leaf` flag, which is how other code
/// in this file recognises that a b_node pointer refers to an interior node.
template<typename T, short t = 2>
struct b_node_internal : b_node<T, t>
{
    b_node<T, t> *children[2 * t];

    b_node_internal(short num_keys) : b_node<T, t>(num_keys)
    {
        this->leaf = false;
    }

    /// Splits the child at position `idx` around its middle key.
    ///
    /// The keys after position t - 1 of that child (and, for an interior
    /// child, the child pointers from position t on) are copied into a newly
    /// allocated sibling of the same kind. The child keeps its first t - 1
    /// keys, the sibling is linked in at position idx + 1, and the child's key
    /// at position t - 1 is inserted into this node at position idx.
    ///
    /// The index arithmetic assumes the child holds 2*t - 1 keys and that this
    /// node has room for one more key; neither is checked here.
    void split_child(short idx)
    {
        typedef b_node<T, t> node_type;
        typedef b_node_internal<T, t> internal_type;

        node_type *const left = children[idx];

        // The new sibling must be the same kind of node as the one being split.
        node_type *right;
        if (left->leaf)
            right = new node_type(t - 1);
        else
            right = new internal_type(t - 1);

        // Upper t - 1 keys go to the sibling.
        std::copy(left->keys + t, left->keys + (2 * t - 1), right->keys);

        if (!left->leaf)
        {
            // Upper t child pointers follow their keys.
            internal_type *const right_internal = static_cast<internal_type *>(right);
            internal_type *const left_internal = static_cast<internal_type *>(left);
            std::copy(left_internal->children + t, left_internal->children + 2 * t, right_internal->children);
        }
        left->num_keys = t - 1;

        const int used_keys = this->num_keys;

        // Shift child pointers one slot to the right, then drop the sibling in
        // directly after the child that was split.
        std::copy_backward(children + idx, children + (used_keys + 1), children + (used_keys + 2));
        children[idx + 1] = right;

        // Shift keys one slot to the right and pull the middle key up.
        std::copy_backward(this->keys + idx, this->keys + used_keys, this->keys + (used_keys + 1));
        this->keys[idx] = left->keys[t - 1];
        this->num_keys++;
    }
};
/// B-tree of keys of type T built from b_node / b_node_internal.
///
/// Supports insertion and membership queries. Inserting a key that is already
/// present stores another copy of it. The tree owns all of its nodes and
/// releases them on destruction; it cannot be copied.
///
/// @tparam T  key type; must work with std::upper_bound, `<` and `==`
/// @tparam t  degree parameter forwarded to the node types
template<typename T, short t = 2>
class b_tree
{
    b_node<T, t> *root;

    /// Frees `node` and everything below it.
    ///
    /// b_node has no virtual destructor, so each node is deleted through a
    /// pointer of its exact type, chosen with the `leaf` flag. An interior
    /// node with num_keys keys has num_keys + 1 children.
    static void destroy_subtree(b_node<T, t> *node)
    {
        if (node->leaf)
        {
            delete node;
            return;
        }
        auto internal_node = static_cast<b_node_internal<T, t> *>(node);
        for (short c = 0; c <= internal_node->num_keys; ++c)
            destroy_subtree(internal_node->children[c]);
        delete internal_node;
    }

    /// Inserts `data` into the subtree rooted at `node`, which must not be
    /// full. Full children met on the way down are split before descending,
    /// so the leaf that is finally reached always has a free slot.
    static void insert_nonfull(b_node<T, t> *node, const T data)
    {
        while (true)
        {
            auto i = my_upper_bound(&node->keys[0], &node->keys[node->num_keys], data) - &node->keys[0];
            if (node->leaf)
            {
                // Open a gap at position i and place the new key there.
                std::copy_backward(&node->keys[i], &node->keys[node->num_keys], &node->keys[node->num_keys + 1]);
                node->keys[i] = data;
                node->num_keys++;
                return;
            }
            auto internal_node = static_cast<b_node_internal<T, t> *>(node);
            if (internal_node->children[i]->is_full())
            {
                internal_node->split_child(i);
                // The split moved a key up to position i; go right of it when
                // the new key is larger.
                if (internal_node->keys[i] < data) i++;
            }
            node = internal_node->children[i];
        }
    }

public:
    /// Creates an empty tree consisting of a single empty leaf.
    b_tree() : root(new b_node<T, t>(0)) { }

    /// Releases every node of the tree.
    ~b_tree()
    {
        destroy_subtree(root);
    }

    // The tree owns its nodes through raw pointers; a member-wise copy would
    // make two trees free the same nodes.
    b_tree(const b_tree &) = delete;
    b_tree &operator=(const b_tree &) = delete;

    /// Adds `data` to the tree. When the root is full, a new interior root is
    /// created above it and the old root is split first.
    void insert(const T data)
    {
        if (root->is_full())
        {
            auto r = new b_node_internal<T, t>(0);
            r->children[0] = root;
            r->split_child(0);
            root = r;
        }
        insert_nonfull(root, data);
    }

    /// @return true if a key equal to `data` is stored in the tree.
    bool search(const T data) const
    {
        auto node = root;
        while (true)
        {
            auto i = my_upper_bound(node->keys, node->keys + node->num_keys, data) - node->keys;
            // upper_bound stops just past any key equal to `data`.
            if (i > 0 && node->keys[i - 1] == data)
                return true;
            if (node->leaf)
                return false;
            node = static_cast<b_node_internal<T, t> *>(node)->children[i];
        }
    }
};
int main()
{
const int count = 1 * 1000 * 1000;
vector<long long> numbers(count);
mt19937_64 gen;
for (int i = 0; i < count; i++)
numbers[i] = gen();
//iota(begin(numbers), end(numbers), 0);
b_tree<long long, 32> tree;
{
stopwatch sw;
for (auto val : numbers)
tree.insert(val);
}
int found = 0;
{
stopwatch sw;
for (auto val : numbers)
{
if (tree.search(val))
found++;
}
}
cout << found << endl;
}
//
//template<typename T, int N>
//class sorted_array
//{
// std::array<T, N> data_;
// int n_ = 0;
//
//public:
// unsigned long long tries = 0;
//
// __declspec(noinline) bool insert(T val)
// {
// auto pos = std::upper_bound(begin(data_), begin(data_) + n_, val);
// if (pos != begin(data_) && *(pos - 1) == val)
// return false;
//
// std::copy_backward(pos, begin(data_) + n_, begin(data_) + n_ + 1);
// *pos = val;
// n_++;
// return true;
// }
//
// __declspec(noinline) bool search(T val)
// {
// int *start = &data_[0], *end = start + n_;
// while (start < end)
// {
// tries++;
//
// auto m = start + (end - start) / 2;
// if (val < *m)
// end = m;
// else if (*m < val)
// start = m + 1;
// else
// return true;
// }
// return false;
// //return std::binary_search(begin(data_), begin(data_) + n_, val);
// }
//};
//
//template<typename T>
//class bst_node
//{
// T val_;
//
//public:
// bst_node<T> *left_, *right_;
//
// bst_node(T val)
// : val_(val), left_(), right_()
// {
// }
//
// ~bst_node()
// {
// delete left;
// delete right;
// }
//
// T key() const
// {
// return val_;
// }
//};
//
//template<typename T>
//class bst
//{
// bst_node<T> *root_ = nullptr;
//
//public:
// unsigned long long tries = 0;
//
// __declspec(noinline) bool insert(T val)
// {
// bst_node<T> **p = &root_;
// while (true)
// {
// if (!*p)
// {
// *p = new bst_node<T>(val);
// return true;
// }
//
// if (val < (*p)->key())
// p = &((*p)->left_);
// else if ((*p)->key() < val)
// p = &((*p)->right_);
// else
// return false;
// }
// }
//
// __declspec(noinline) bool search(T val)
// {
// auto p = root_;
// while (p)
// {
// tries++;
//
// if (val < p->key())
// p = p->left_;
// else if (p->key() < val)
// p = p->right_;
// else
// return true;
// }
// return false;
// }
//};
//int main2()
//{
// const int runs = 100000 * 16;
// const int count = 64;
//
// vector<int> numbers(count);
// std::iota(begin(numbers), end(numbers), 0);
// vector<int> queries(numbers);
// std::random_shuffle(begin(numbers), end(numbers));
// std::random_shuffle(begin(queries), end(queries));
//
// sorted_array<int, count> sa;
// int c = 0;
//
// for (auto val : numbers)
// sa.insert(val);
// {
// stopwatch sw;
// for (int run = 0; run < runs; run++)
// {
// for (auto val : queries)
// {
// if (sa.search(val))
// c++;
// }
// for (auto val : queries)
// {
// if (sa.search(val))
// c++;
// }
// }
// }
//
// cout << c << endl;
// c = 0;
//
// bst<int> bst;
//
// for (auto val : numbers)
// bst.insert(val);
//
// {
// stopwatch sw;
// for (long long run = 0; run < runs; run++)
// {
// for (auto val : queries)
// {
// if (bst.search(val))
// c++;
// }
// for (auto val : queries)
// {
// if (bst.search(val))
// c++;
// }
// }
// }
//
// cout << c << endl;
//
// cout << sa.tries << endl << bst.tries << endl;
// return c;
//}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment