Created
September 22, 2015 00:56
-
-
Save goldsborough/fb5c8c508e577696745d to your computer and use it in GitHub Desktop.
A red-black-tree implementation.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Ordered associative container mapping unique keys to values, stored as a
 * left-leaning red-black binary search tree.
 *
 * Keys are ordered with operator< only; two keys are treated as equal when
 * neither compares less than the other. Every node caches the number of
 * nodes in its subtree, which is what the rank-based queries rely on.
 *
 * Ranks are 1-based throughout: the smallest key has rank 1, and
 * select(rank(k)) == k for every stored key k.
 *
 * Key and Value must be copy-constructible, copy-assignable and
 * default-constructible (Pair's constructor uses Key() and Value()).
 *
 * Needs <cassert>, <cstddef>, <initializer_list>, <iterator>, <stdexcept>,
 * <utility> and <vector> to be included before this definition.
 *
 * Note: when Key is an integral type, size(rank, rank) and size(key, key)
 * compete in overload resolution; pass std::size_t arguments to pick the
 * rank-based overload.
 */
template<typename Key, typename Value>
class Tree
{
public:

    /** A key together with the value stored under it. */
    struct Pair
    {
        Pair(const Key& k = Key(), const Value& v = Value())
        : key(k)
        , value(v)
        { }

        Key key;
        Value value;
    };

    /** Creates an empty tree. */
    Tree()
    : _root(nullptr)
    { }

    /**
     * Creates a tree from a list of pairs, inserted in list order (a later
     * duplicate key overwrites the earlier value).
     *
     * Delegates to the default constructor so that the destructor releases
     * already-inserted nodes if a later insertion throws.
     */
    Tree(const std::initializer_list<Pair>& list)
    : Tree()
    {
        for (const auto& pair : list) insert(pair);
    }

    /** Deep-copies other, reproducing its exact shape, sizes and colors. */
    Tree(const Tree& other)
    : _root(_clone(other._root))
    { }

    /** Takes over other's nodes and leaves other empty. */
    Tree(Tree&& other) noexcept
    : Tree()
    {
        swap(other);
    }

    /**
     * Copy-and-swap assignment; serves both copy and move assignment because
     * the argument is taken by value.
     *
     * @return *this
     */
    Tree& operator=(Tree other) noexcept
    {
        swap(other);

        return *this;
    }

    /** Exchanges the contents of the two trees by swapping root pointers. */
    void swap(Tree& other) noexcept
    {
        using std::swap;

        swap(_root, other._root);
    }

    friend void swap(Tree& first, Tree& other) noexcept
    {
        first.swap(other);
    }

    ~Tree()
    {
        clear();
    }

    /** Inserts pair, or overwrites the value if pair.key is already stored. */
    void insert(const Pair& pair)
    {
        _root = _insert(_root, pair);

        // The root of a red-black tree is always black.
        _root->color = Color::BLACK;
    }

    /** Inserts (key, value), or overwrites the value of an existing key. */
    void insert(const Key& key, const Value& value)
    {
        insert(Pair(key, value));
    }

    /**
     * Removes key and its value.
     *
     * @throws std::invalid_argument if key is not stored; the tree is left
     *         untouched in that case.
     */
    void erase(const Key& key)
    {
        if (! contains(key)) throw std::invalid_argument("No such key!");

        // The top-down deletion needs the current node or its left child to
        // be red at every step; establish that at the root.
        if (! _is_red(_root->left) && ! _is_red(_root->right))
        {
            _root->color = Color::RED;
        }

        _root = _erase(_root, key);

        if (_root) _root->color = Color::BLACK;
    }

    /** Removes all nodes. */
    void clear()
    {
        _clear(_root);

        _root = nullptr;
    }

    /** @return true if key is stored in the tree. */
    bool contains(const Key& key) const
    {
        return _find(_root, key) != nullptr;
    }

    /**
     * @return the value stored under key.
     * @throws std::invalid_argument if key is not stored.
     */
    Value& at(const Key& key)
    {
        // Reuse the const lookup; the cast is safe because *this is non-const.
        const Tree& self = *this;

        return const_cast<Value&>(self.at(key));
    }

    /**
     * @return the value stored under key.
     * @throws std::invalid_argument if key is not stored.
     */
    const Value& at(const Key& key) const
    {
        const Node* node = _find(_root, key);

        if (! node) throw std::invalid_argument("No such key!");

        return node->value;
    }

    /**
     * @return the value stored under key, first inserting a
     *         default-constructed Value if key is not yet stored.
     */
    Value& operator[](const Key& key)
    {
        if (! contains(key)) insert(key, Value());

        return at(key);
    }

    /**
     * Read-only lookup. A const tree cannot be extended, so this behaves
     * like at().
     *
     * @throws std::invalid_argument if key is not stored.
     */
    const Value& operator[](const Key& key) const
    {
        return at(key);
    }

    /**
     * @return the smallest key.
     * @throws std::runtime_error if the tree is empty.
     */
    const Key& minimum() const
    {
        if (! _root) throw std::runtime_error("Tree is empty!");

        const Node* node = _root;

        while (node->left) node = node->left;

        return node->key;
    }

    /**
     * @return the largest key.
     * @throws std::runtime_error if the tree is empty.
     */
    const Key& maximum() const
    {
        if (! _root) throw std::runtime_error("Tree is empty!");

        const Node* node = _root;

        while (node->right) node = node->right;

        return node->key;
    }

    /**
     * @return the largest stored key that is less than or equal to key.
     * @throws std::runtime_error if every stored key is greater than key.
     */
    const Key& floor(const Key& key) const
    {
        const Node* best = nullptr;

        for (const Node* node = _root; node; )
        {
            if (key < node->key) node = node->left;

            else
            {
                best = node;

                // Exact match: nothing further right can be closer.
                if (! (node->key < key)) break;

                node = node->right;
            }
        }

        if (! best) throw std::runtime_error("No floor for given key!");

        return best->key;
    }

    /**
     * @return the smallest stored key that is greater than or equal to key.
     * @throws std::runtime_error if every stored key is less than key.
     */
    const Key& ceiling(const Key& key) const
    {
        const Node* best = nullptr;

        for (const Node* node = _root; node; )
        {
            if (node->key < key) node = node->right;

            else
            {
                best = node;

                // Exact match: nothing further left can be closer.
                if (! (key < node->key)) break;

                node = node->left;
            }
        }

        if (! best) throw std::runtime_error("No ceiling for given key!");

        return best->key;
    }

    /**
     * @return the number of stored keys less than or equal to key; for a
     *         stored key this is its 1-based rank.
     */
    std::size_t rank(const Key& key) const
    {
        std::size_t count = 0;

        for (const Node* node = _root; node; )
        {
            if (key < node->key) node = node->left;

            else
            {
                // This node and its whole left subtree are <= key.
                count += _size_of(node->left) + 1;

                if (! (node->key < key)) break;

                node = node->right;
            }
        }

        return count;
    }

    /**
     * @param rank 1-based position in sorted key order.
     * @return the key at that position.
     * @throws std::invalid_argument if rank is 0 or greater than size().
     */
    const Key& select(std::size_t rank) const
    {
        if (rank == 0 || rank > size())
        {
            throw std::invalid_argument("No such key for given rank!");
        }

        const Node* node = _root;

        // rank is within [1, node->size] on every iteration, so the loop
        // always ends on a node.
        for (;;)
        {
            const std::size_t own_rank = _size_of(node->left) + 1;

            if (rank < own_rank) node = node->left;

            else if (rank > own_rank)
            {
                rank -= own_rank;
                node = node->right;
            }

            else break;
        }

        return node->key;
    }

    /**
     * @return the number of stored keys whose 1-based rank lies in the
     *         inclusive range [lower, upper]; 0 if the range is empty.
     */
    std::size_t size(std::size_t lower, std::size_t upper) const
    {
        const std::size_t total = size();

        // Clamp to the ranks that actually exist, i.e. [1, total].
        if (lower == 0) lower = 1;
        if (upper > total) upper = total;

        return (lower > upper) ? 0 : upper - lower + 1;
    }

    /**
     * @return the number of stored keys in the inclusive range [lower, upper].
     * @throws std::invalid_argument if upper < lower.
     */
    std::size_t size(const Key& lower, const Key& upper) const
    {
        if (upper < lower) throw std::invalid_argument("First key must be "
                                                       "less than second!");

        // rank() counts keys <= its argument, so the difference counts the
        // keys in (lower, upper]; lower itself is added only if stored.
        std::size_t count = rank(upper) - rank(lower);

        if (contains(lower)) ++count;

        return count;
    }

    /** @return the number of stored keys. */
    std::size_t size() const
    {
        return _size_of(_root);
    }

    /** @return true if the tree holds no keys. */
    bool is_empty() const
    {
        return ! _root;
    }

    /**
     * @return the keys whose 1-based rank lies in the inclusive range
     *         [lower, upper], in ascending order.
     */
    std::vector<Key> keys(std::size_t lower, std::size_t upper) const
    {
        std::vector<Key> destination;

        _keys(_root, lower, upper, destination);

        return destination;
    }

    /** @return all keys in ascending order. */
    std::vector<Key> keys() const
    {
        std::vector<Key> k;

        k.reserve(size());

        _collect(_root, std::back_inserter(k),
                 [] (const Node* node) { return node->key; });

        return k;
    }

    /** @return all values, ordered by ascending key. */
    std::vector<Value> values() const
    {
        std::vector<Value> v;

        v.reserve(size());

        _collect(_root, std::back_inserter(v),
                 [] (const Node* node) { return node->value; });

        return v;
    }

    /** @return all (key, value) pairs, ordered by ascending key. */
    std::vector<Pair> pairs() const
    {
        std::vector<Pair> p;

        p.reserve(size());

        _collect(_root,
                 std::back_inserter(p),
                 [] (const Node* node) { return node->pair(); });

        return p;
    }

private:

    /** Color of the link from a node's parent to the node. */
    enum class Color
    {
        RED,
        BLACK
    };

    struct Node
    {
        /** Creates a red leaf holding the given pair. */
        Node(const Pair& pair)
        : left(nullptr)
        , right(nullptr)
        , key(pair.key)
        , value(pair.value)
        , size(1)
        , color(Color::RED)
        { }

        /** Creates a red leaf holding the given key and value. */
        Node(const Key& k, const Value& v)
        : left(nullptr)
        , right(nullptr)
        , key(k)
        , value(v)
        , size(1)
        , color(Color::RED)
        { }

        /** Recomputes the cached subtree size from the children's sizes. */
        void resize()
        {
            size = 1;

            if (left) size += left->size;
            if (right) size += right->size;
        }

        Pair pair() const
        {
            return {key, value};
        }

        Node* left;
        Node* right;

        Key key;
        Value value;

        // Number of nodes in the subtree rooted here, this node included.
        std::size_t size;

        Color color;
    };

    /** @return the cached size of the subtree at node, 0 for nullptr. */
    static std::size_t _size_of(const Node* node)
    {
        return node ? node->size : 0;
    }

    /** @return true if node exists and is red; a null link counts as black. */
    static bool _is_red(const Node* node)
    {
        return node && node->color == Color::RED;
    }

    /**
     * Recursively copies the subtree at source, including sizes and colors.
     * If an allocation or copy throws, the partial copy is freed first.
     */
    static Node* _clone(const Node* source)
    {
        if (! source) return nullptr;

        Node* copy = new Node(source->key, source->value);

        copy->size = source->size;
        copy->color = source->color;

        try
        {
            copy->left = _clone(source->left);
            copy->right = _clone(source->right);
        }

        catch (...)
        {
            _clear(copy);
            throw;
        }

        return copy;
    }

    /** Inserts pair below node and returns the rebalanced subtree root. */
    static Node* _insert(Node* node, const Pair& pair)
    {
        if (! node) return new Node(pair);

        if (pair.key < node->key)
        {
            node->left = _insert(node->left, pair);
        }

        else if (node->key < pair.key)
        {
            node->right = _insert(node->right, pair);
        }

        else node->value = pair.value;

        return _balance(node);
    }

    /**
     * Removes key from the subtree at node using top-down left-leaning
     * red-black deletion, and returns the new subtree root.
     *
     * Preconditions (established by erase()): key is present in the subtree,
     * and node or node->left is red.
     */
    static Node* _erase(Node* node, const Key& key)
    {
        if (key < node->key)
        {
            // Make sure the next node on the path is not a 2-node.
            if (! _is_red(node->left) && ! _is_red(node->left->left))
            {
                node = _move_red_left(node);
            }

            node->left = _erase(node->left, key);
        }

        else
        {
            // Lean a red left link to the right so red can be carried down.
            if (_is_red(node->left)) node = _rotate_right(node);

            // In this branch key >= node->key, so "not less" means equal.
            if (! (node->key < key) && ! node->right)
            {
                // Balance implies left is null here; returning it anyway
                // guarantees nothing is ever orphaned.
                Node* rest = node->left;

                delete node;

                return rest;
            }

            if (! _is_red(node->right) && ! _is_red(node->right->left))
            {
                node = _move_red_right(node);
            }

            if (! (node->key < key))
            {
                // Overwrite with the in-order successor, then remove the
                // successor's node from the right subtree. key is not read
                // again after this point, so it may alias node->key.
                const Node* successor = node->right;

                while (successor->left) successor = successor->left;

                node->key = successor->key;
                node->value = successor->value;

                node->right = _erase_minimum(node->right);
            }

            else node->right = _erase(node->right, key);
        }

        return _balance(node);
    }

    /**
     * Removes the smallest node of the subtree at node and returns the new
     * subtree root. Precondition: node or node->left is red.
     */
    static Node* _erase_minimum(Node* node)
    {
        if (! node->left)
        {
            Node* rest = node->right;

            delete node;

            return rest;
        }

        if (! _is_red(node->left) && ! _is_red(node->left->left))
        {
            node = _move_red_left(node);
        }

        node->left = _erase_minimum(node->left);

        return _balance(node);
    }

    /** Frees every node of the subtree at node. */
    static void _clear(Node* node)
    {
        if (! node) return;

        _clear(node->left);
        _clear(node->right);

        delete node;
    }

    /** @return the node holding key in the subtree at node, or nullptr. */
    static const Node* _find(const Node* node, const Key& key)
    {
        while (node)
        {
            if (key < node->key) node = node->left;

            else if (node->key < key) node = node->right;

            else break;
        }

        return node;
    }

    /**
     * Appends, in ascending order, the keys of the subtree at node whose
     * rank within that subtree (1-based) lies in [lower, upper].
     */
    static void _keys(const Node* node,
                      std::size_t lower,
                      std::size_t upper,
                      std::vector<Key>& destination)
    {
        if (! node || upper == 0 || lower > upper) return;

        const std::size_t rank = _size_of(node->left) + 1;

        // Ranks 1 .. rank - 1 live in the left subtree.
        if (lower < rank) _keys(node->left, lower, upper, destination);

        if (lower <= rank && rank <= upper) destination.push_back(node->key);

        // Ranks in the right subtree are relative to it; shift the window
        // down by rank without letting the lower bound underflow.
        if (upper > rank)
        {
            _keys(node->right,
                  (lower > rank) ? lower - rank : 1,
                  upper - rank,
                  destination);
        }
    }

    /**
     * In-order traversal that assigns function(node) through itr for every
     * node. itr is passed by value, so it must be an output iterator whose
     * copies write to the same destination (e.g. std::back_inserter).
     */
    template<typename DestinationIterator, typename Function>
    static void _collect(const Node* node,
                         DestinationIterator itr,
                         const Function& function)
    {
        if (! node) return;

        _collect(node->left, itr, function);

        itr = function(node);

        _collect(node->right, itr, function);
    }

    /**
     * Restores the left-leaning red-black invariants at node after a change
     * below it, refreshes its cached size, and returns the subtree root.
     */
    static Node* _balance(Node* node)
    {
        if (_is_red(node->right) && ! _is_red(node->left))
        {
            node = _rotate_left(node);
        }

        if (_is_red(node->left) && _is_red(node->left->left))
        {
            node = _rotate_right(node);
        }

        if (_is_red(node->left) && _is_red(node->right))
        {
            _flip_colors(node);
        }

        node->resize();

        return node;
    }

    /**
     * With node red and node->left, node->left->left both black, makes
     * node->left or one of its children red. Returns the subtree root.
     */
    static Node* _move_red_left(Node* node)
    {
        _flip_colors(node);

        if (_is_red(node->right->left))
        {
            node->right = _rotate_right(node->right);
            node = _rotate_left(node);

            _flip_colors(node);
        }

        return node;
    }

    /**
     * With node red and node->right, node->right->left both black, makes
     * node->right or one of its children red. Returns the subtree root.
     */
    static Node* _move_red_right(Node* node)
    {
        _flip_colors(node);

        if (_is_red(node->left->left))
        {
            node = _rotate_right(node);

            _flip_colors(node);
        }

        return node;
    }

    /**
     * Rotates node's red right child up into node's place; node becomes its
     * red left child. Cached sizes of both nodes are updated.
     */
    static Node* _rotate_left(Node* node)
    {
        Node* pivot = node->right;

        assert(_is_red(pivot));

        node->right = pivot->left;
        pivot->left = node;

        pivot->color = node->color;
        node->color = Color::RED;

        // node is now below pivot, so its size must be refreshed first.
        node->resize();
        pivot->resize();

        return pivot;
    }

    /**
     * Rotates node's red left child up into node's place; node becomes its
     * red right child. Cached sizes of both nodes are updated.
     */
    static Node* _rotate_right(Node* node)
    {
        Node* pivot = node->left;

        assert(_is_red(pivot));

        node->left = pivot->right;
        pivot->right = node;

        pivot->color = node->color;
        node->color = Color::RED;

        // node is now below pivot, so its size must be refreshed first.
        node->resize();
        pivot->resize();

        return pivot;
    }

    /** @return the other color. */
    static Color _opposite(Color color)
    {
        return (color == Color::RED) ? Color::BLACK : Color::RED;
    }

    /**
     * Inverts the colors of node and both children. Used in both directions:
     * splitting a 4-node on insertion and merging on deletion.
     */
    static void _flip_colors(Node* node)
    {
        assert(node->left && node->right);

        node->color = _opposite(node->color);
        node->left->color = _opposite(node->left->color);
        node->right->color = _opposite(node->right->color);
    }

    Node* _root;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment