Last active
February 24, 2018 07:41
-
-
Save zhenghaoz/1743e089bfa3b8ba3074bf45d7cdda19 to your computer and use it in GitHub Desktop.
B Tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
template <unsigned N, typename Key, typename Value> | |
class BTree | |
{ | |
template <typename T> using vector = std::vector<T>; | |
template <typename T> using shared_ptr = std::shared_ptr<T>; | |
struct Node | |
{ | |
bool _leaf; | |
int _size; | |
vector<Key> _keys = vector<Key>(2*N-1); | |
vector<Value> _values = vector<Value>(2*N-1); | |
vector<shared_ptr<Node>> _children = vector<shared_ptr<Node>>(2*N); | |
Node() = default; | |
Node(const Node &node): _leaf(node._leaf), _size(node._size), _keys(node._keys), _values(node._values) | |
{ | |
if (!_leaf) | |
for (int i = 0; i <= _size; i++) | |
_children[i] = std::make_shared<Node>(*node._children[i]); | |
} | |
}; | |
shared_ptr<Node> root; | |
// find k-v in a node | |
Value* find(shared_ptr<Node> node, Key key) | |
{ | |
// search key in node | |
int i = 0; | |
while (i < node->_size && key > node->_keys[i]) | |
i++; | |
if (i < node->_size && key == node->_keys[i]) | |
return &node->_values[i]; | |
else if (node->_leaf) | |
return nullptr; | |
else return find(node->_children[i], key); | |
} | |
// split a full node (child.size == 2*N-1) | |
void split(shared_ptr<Node> parent, int i, shared_ptr<Node> child) | |
{ | |
shared_ptr<Node> nchild = std::make_shared<Node>(); | |
nchild->_leaf = child->_leaf; | |
nchild->_size = child->_size = N-1; | |
// move k-v | |
for (int j = 0; j < N-1; j++) { | |
nchild->_keys[j] = child->_keys[j + N]; | |
nchild->_values[j] = child->_values[j + N]; | |
} | |
// move children | |
if (!child->_leaf) | |
for (int j = 0; j < N; j++) | |
nchild->_children[j] = child->_children[j + N]; | |
// move child->key[N-1] up | |
for (int j = parent->_size; j > i; j--) { | |
parent->_keys[j] = parent->_keys[j-1]; | |
parent->_values[j] = parent->_values[j-1]; | |
parent->_children[j+1] = parent->_children[j]; | |
} | |
parent->_keys[i] = child->_keys[N-1]; | |
parent->_values[i] = child->_values[N-1]; | |
parent->_children[i+1] = nchild; | |
parent->_size++; | |
} | |
// insert k-v in a node | |
void insert(shared_ptr<Node> node, Key key, Value value) | |
{ | |
// find insert position | |
int i = 0; | |
while (i < node->_size && key > node->_keys[i]) | |
i++; | |
if (node->_leaf) { // insert k-v in a leaf | |
for (int j = node->_size; j > i; j--) { | |
node->_keys[j] = node->_keys[j-1]; | |
node->_values[j] = node->_values[j-1]; | |
} | |
node->_keys[i] = key; | |
node->_values[i] = value; | |
node->_size++; | |
} else { // insert k-v in subNode | |
shared_ptr<Node> ptr = node->_children[i]; | |
if (ptr->_size == 2*N-1) { | |
split(node, i, ptr); | |
if (key > node->_keys[i]) | |
i++; | |
} | |
insert(node->_children[i], key, value); | |
} | |
} | |
// insert k-v in root node | |
void insert(Key key, Value value) | |
{ | |
shared_ptr<Node> ptr = root; | |
if (ptr->_size == 2*N-1) { // split root node | |
root = std::make_shared<Node>(); | |
root->_leaf = false; | |
root->_size = 0; | |
root->_children[0] = ptr; | |
split(root, 0, ptr); | |
insert(root, key, value); | |
} else insert(root, key, value); | |
} | |
// combine children[i] and children[i+1] | |
void combine(shared_ptr<Node> parent, int i) | |
{ | |
shared_ptr<Node> prev = parent->_children[i]; | |
shared_ptr<Node> next = parent->_children[i+1]; | |
// move parent->key[i] down | |
prev->_keys[prev->_size] = parent->_keys[i]; | |
prev->_values[prev->_size] = parent->_values[i]; | |
prev->_size++; | |
// move k-v from next to prev | |
for (int j = 0; j < next->_size; j++) { | |
prev->_keys[j + prev->_size] = next->_keys[j]; | |
prev->_values[j + prev->_size] = next->_values[j]; | |
} | |
if (!prev->_leaf) | |
for (int j = 0; j <= next->_size; j++) | |
prev->_children[j + prev->_size] = next->_children[j]; | |
prev->_size += next->_size; | |
// remove parent->key[i] | |
parent->_size--; | |
for (int j = i; j < parent->_size; j++) { | |
parent->_keys[j] = parent->_keys[j+1]; | |
parent->_values[j] = parent->_values[j+1]; | |
parent->_children[j+1] = parent->_children[j+2]; | |
} | |
} | |
shared_ptr<Node> max(shared_ptr<Node> node) | |
{ | |
shared_ptr<Node> ptr = node; | |
while (!ptr->_leaf) | |
ptr = ptr->_children[ptr->_size]; | |
return ptr; | |
} | |
shared_ptr<Node> min(shared_ptr<Node> node) | |
{ | |
shared_ptr<Node> ptr = node; | |
while (!ptr->_leaf) | |
ptr = ptr->_children[0]; | |
return ptr; | |
} | |
// remove key from node, key must be in node | |
void remove(shared_ptr<Node> node, Key key) | |
{ | |
// find delete position | |
int i = 0; | |
while (i < node->_size && key > node->_keys[i]) | |
i++; | |
if (node->_leaf) { // case 1: remove k-v from leaf | |
node->_size--; | |
for (int j = i; j < node->_size; j++) { | |
node->_keys[j] = node->_keys[j+1]; | |
node->_values[j] = node->_values[j+1]; | |
} | |
} else if (i < node->_size && key == node->_keys[i]) { // case 2: find key in internal node | |
shared_ptr<Node> prevChild = node->_children[i]; | |
shared_ptr<Node> nextChild = node->_children[i+1]; | |
if (prevChild->_size >= N) { // case 2a: move precursor to the position of key | |
shared_ptr<Node> maxNode = max(prevChild); | |
node->_keys[i] = maxNode->_keys[maxNode->_size-1]; | |
node->_values[i] = maxNode->_values[maxNode->_size-1]; | |
remove(prevChild, maxNode->_keys[maxNode->_size-1]); | |
} else if (nextChild->_size >= N) { // case 2b: move successor to the position of key | |
shared_ptr<Node> minNode = min(nextChild); | |
node->_keys[i] = minNode->_keys[0]; | |
node->_values[i] = minNode->_values[0]; | |
remove(nextChild, minNode->_keys[0]); | |
} else { // case 2c: combine previous child and next child | |
combine(node, i); | |
remove(node->_children[i], key); | |
} | |
} else { // case 3 | |
shared_ptr<Node> subNode = node->_children[i]; | |
if (subNode->_size < N) { | |
shared_ptr<Node> prevBrother, nextBrother; | |
if (i > 0) prevBrother = node->_children[i-1]; | |
if (i < node->_size) nextBrother = node->_children[i+1]; | |
if (prevBrother && prevBrother->_size >= N) { // case 3a | |
// remove node->key[i] into subNode | |
for (int j = subNode->_size; j > 0; j--) { | |
subNode->_keys[j] = subNode->_keys[j-1]; | |
subNode->_values[j] = subNode->_values[j-1]; | |
} | |
if (!subNode->_leaf) | |
for (int j = subNode->_size; j >= 0; j--) | |
subNode->_children[j+1] = subNode->_children[j]; | |
subNode->_keys[0] = node->_keys[i-1]; | |
subNode->_values[0] = node->_values[i-1]; | |
subNode->_children[0] = prevBrother->_children[prevBrother->_size]; | |
subNode->_size++; | |
// remove prevBrother->key[prevBrother->size-1] into node | |
node->_keys[i-1] = prevBrother->_keys[prevBrother->_size-1]; | |
node->_values[i-1] = prevBrother->_values[prevBrother->_size-1]; | |
prevBrother->_size--; | |
} else if (nextBrother && nextBrother->_size >= N) { // case 3a | |
// remove node->key[i] into subNode | |
subNode->_keys[subNode->_size] = node->_keys[i]; | |
subNode->_values[subNode->_size] = node->_values[i]; | |
subNode->_children[subNode->_size+1] = nextBrother->_children[0]; | |
subNode->_size++; | |
// remove nextBrother->key[0] into node | |
node->_keys[i] = nextBrother->_keys[0]; | |
node->_values[i] = nextBrother->_values[0]; | |
nextBrother->_size--; | |
for (int j = 0; j < nextBrother->_size; j++) { | |
nextBrother->_keys[j] = nextBrother->_keys[j+1]; | |
nextBrother->_values[j] = nextBrother->_values[j+1]; | |
} | |
if (!nextBrother->_leaf) | |
for (int j = 0; j <= nextBrother->_size; j++) | |
nextBrother->_children[j] = nextBrother->_children[j+1]; | |
} else if (nextBrother) { // case 3b: combine child[i] and child[i+1] | |
combine(node, i); | |
} else { // case 3b: combine child[i-1] and child[i] | |
i--; | |
combine(node, i); | |
} | |
} | |
remove(node->_children[i], key); | |
} | |
} | |
void print(shared_ptr<Node> node, int level) | |
{ | |
for (int i = 0; i < level; i++) | |
std::cout << ' '; | |
std::cout << "{"; | |
for (int i = 0; i < node->_size; i++) { | |
if (i) std::cout << ","; | |
std::cout << node->_keys[i] << ":" << node->_values[i]; | |
} | |
std::cout << "}" << std::endl; | |
if (!node->_leaf) | |
for (int i = 0; i <= node->_size; i++) | |
print(node->_children[i], level+1); | |
} | |
public: | |
BTree() | |
{ | |
root = std::make_shared<Node>(); | |
root->_leaf = true; | |
root->_size = 0; | |
} | |
BTree(const BTree &tree) | |
{ | |
root = std::make_shared<Node>(*tree.root); | |
} | |
BTree& operator=(BTree tree) | |
{ | |
std::swap(root, tree.root); | |
} | |
Value* get(const Key &key) | |
{ | |
return find(root, key); | |
} | |
void put(const Key &key, const Value &value) | |
{ | |
Value *val = find(root, key); | |
if (val) | |
*val = value; | |
else | |
insert(key, value); | |
} | |
void remove(const Key &key) | |
{ | |
if (find(root, key)) | |
remove(root, key); | |
if (root->_size == 0) | |
root = root->_children[0]; | |
if (root == nullptr) { | |
root = std::make_shared<Node>(); | |
root->_leaf = true; | |
root->_size = 0; | |
} | |
} | |
void print() | |
{ | |
print(root, 0); | |
} | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment