Skip to content

Instantly share code, notes, and snippets.

@goldsborough
Created September 22, 2015 00:56
Show Gist options
  • Save goldsborough/fb5c8c508e577696745d to your computer and use it in GitHub Desktop.
Save goldsborough/fb5c8c508e577696745d to your computer and use it in GitHub Desktop.
A red-black-tree implementation.
/// An ordered symbol table backed by a left-leaning red-black tree (LLRB).
///
/// Keys must support ==, < and >. Every node caches the size of its subtree,
/// which powers the order-statistic queries (rank/select and the
/// rank-indexed size/keys overloads). Ranks are 1-based: the smallest key
/// has rank 1, and rank(k) is the number of keys <= k.
template<typename Key, typename Value>
class Tree
{
public:
    /// A key/value pair, as accepted by insert() and returned by pairs().
    struct Pair
    {
        Pair(const Key& k = Key(), const Value& v = Value())
        : key(k)
        , value(v)
        { }

        Key key;
        Value value;
    };

    /// Constructs an empty tree.
    Tree()
    : _root(nullptr)
    { }

    /// Constructs a tree holding every pair of @p list; a repeated key keeps
    /// the value of its last occurrence. Delegating to Tree() means the
    /// destructor frees already-built nodes if an insertion throws.
    Tree(const std::initializer_list<Pair>& list)
    : Tree()
    {
        for (const auto& pair : list) insert(pair);
    }

    /// Deep copy. Delegating to Tree() null-initializes _root before the
    /// first insertion (it was previously read uninitialized) and keeps the
    /// copy leak-free if an allocation throws halfway through.
    Tree(const Tree& other)
    : Tree()
    {
        _copy(other._root);
    }

    /// Takes over @p other's nodes, leaving @p other empty.
    Tree(Tree&& other) noexcept
    : Tree()
    {
        swap(other);
    }

    /// Copy-and-swap assignment; serves as both copy and move assignment.
    Tree& operator=(Tree other) noexcept
    {
        swap(other);
        return *this;
    }

    void swap(Tree& other) noexcept
    {
        using std::swap;
        swap(_root, other._root);
    }

    friend void swap(Tree& first, Tree& other) noexcept
    {
        first.swap(other);
    }

    ~Tree()
    {
        clear();
    }

    /// Inserts @p pair, overwriting the stored value if the key is present.
    void insert(const Pair& pair)
    {
        _root = _insert(_root, pair);
        // The root of a red-black tree is always black.
        _root->color = Color::BLACK;
    }

    void insert(const Key& key, const Value& value)
    {
        insert({key, value});
    }

    /// Removes @p key.
    /// @throws std::invalid_argument if the key is absent. The check is done
    ///         up front so a failed erase leaves the tree untouched.
    void erase(const Key& key)
    {
        if (! contains(key)) throw std::invalid_argument("No such key!");
        // Top-down LLRB deletion needs a red link on the way down; if the
        // root has no red child, temporarily color the root itself red.
        if (! _is_red(_root->left) && ! _is_red(_root->right))
        {
            _root->color = Color::RED;
        }
        _root = _erase(_root, key);
        if (_root) _root->color = Color::BLACK;
    }

    /// Deletes every node.
    void clear()
    {
        _clear(_root);
        _root = nullptr;
    }

    bool contains(const Key& key) const
    {
        return _find(_root, key) != nullptr;
    }

    /// @throws std::invalid_argument if @p key is absent.
    Value& at(const Key& key)
    {
        Node* node = _find(_root, key);
        if (! node) throw std::invalid_argument("No such key!");
        return node->value;
    }

    /// @throws std::invalid_argument if @p key is absent.
    const Value& at(const Key& key) const
    {
        const Node* node = _find(_root, key);
        if (! node) throw std::invalid_argument("No such key!");
        return node->value;
    }

    /// Returns the value for @p key, inserting a value-initialized one first
    /// if the key is absent (insert() keeps the tree balanced).
    Value& operator[](const Key& key)
    {
        Node* node = _find(_root, key);
        if (! node)
        {
            insert(key, Value());
            node = _find(_root, key);
        }
        return node->value;
    }

    /// A const tree cannot grow, so this behaves like at().
    /// @throws std::invalid_argument if @p key is absent.
    const Value& operator[](const Key& key) const
    {
        return at(key);
    }

    /// @throws std::runtime_error if the tree is empty.
    const Key& minimum() const
    {
        if (! _root) throw std::runtime_error("Tree is empty!");
        const Node* node = _root;
        while (node->left) node = node->left;
        return node->key;
    }

    /// @throws std::runtime_error if the tree is empty.
    const Key& maximum() const
    {
        if (! _root) throw std::runtime_error("Tree is empty!");
        const Node* node = _root;
        while (node->right) node = node->right;
        return node->key;
    }

    /// Largest key <= @p key.
    /// @throws std::runtime_error if there is none.
    const Key& floor(const Key& key) const
    {
        const Node* node = _floor(_root, key);
        if (! node) throw std::runtime_error("No floor for given key!");
        return node->key;
    }

    /// Smallest key >= @p key.
    /// @throws std::runtime_error if there is none.
    const Key& ceiling(const Key& key) const
    {
        const Node* node = _ceiling(_root, key);
        if (! node) throw std::runtime_error("No ceiling for given key!");
        return node->key;
    }

    /// Number of keys <= @p key (the 1-based rank of @p key if present).
    std::size_t rank(const Key& key) const
    {
        return _rank(_root, key);
    }

    /// Key with the given 1-based rank; inverse of rank() for stored keys.
    /// @throws std::invalid_argument if rank is 0 or exceeds size().
    const Key& select(std::size_t rank) const
    {
        const Node* node = _select(_root, rank);
        if (! node) throw std::invalid_argument("No such key for given rank!");
        return node->key;
    }

    /// Number of keys whose 1-based rank lies in [lower, upper].
    std::size_t size(std::size_t lower, std::size_t upper) const
    {
        // Ranks are exactly 1..size(), so the answer is the length of the
        // overlap between [lower, upper] and [1, size()].
        const std::size_t first = lower ? lower : 1;
        const std::size_t last = upper < size() ? upper : size();
        return first <= last ? last - first + 1 : 0;
    }

    /// Number of keys k with lower <= k <= upper.
    /// @throws std::invalid_argument if upper < lower.
    std::size_t size(const Key& lower, const Key& upper) const
    {
        if (upper < lower) throw std::invalid_argument("First key must be "
                                                       "less than second!");
        // rank(x) counts keys <= x, so the difference counts keys in
        // (lower, upper]; lower itself is added only when it is stored.
        const std::size_t lower_rank = rank(lower);
        const std::size_t upper_rank = rank(upper);
        return upper_rank - lower_rank + (contains(lower) ? 1 : 0);
    }

    std::size_t size() const
    {
        return _root ? _root->size : 0;
    }

    bool is_empty() const
    {
        return ! _root;
    }

    /// Keys whose 1-based rank lies in [lower, upper], in ascending order.
    std::vector<Key> keys(std::size_t lower, std::size_t upper) const
    {
        std::vector<Key> destination;
        _keys(_root, lower, upper, 0, destination);
        return destination;
    }

    /// All keys in ascending order.
    std::vector<Key> keys() const
    {
        std::vector<Key> k;
        _collect(_root, std::back_inserter(k),
                 [] (const Node* node) { return node->key; });
        return k;
    }

    /// All values, ordered by their keys.
    std::vector<Value> values() const
    {
        std::vector<Value> v;
        _collect(_root, std::back_inserter(v),
                 [] (const Node* node) { return node->value; });
        return v;
    }

    /// All key/value pairs in ascending key order.
    std::vector<Pair> pairs() const
    {
        std::vector<Pair> p;
        _collect(_root,
                 std::back_inserter(p),
                 [] (const Node* node) { return node->pair(); });
        return p;
    }

private:
    enum class Color
    {
        RED,
        BLACK
    };

    struct Node
    {
        explicit Node(const Pair& pair)
        : left(nullptr)
        , right(nullptr)
        , key(pair.key)
        , value(pair.value)
        , size(1)
        , color(Color::RED)
        { }

        /// Recomputes the cached subtree size from the children's sizes.
        void resize()
        {
            size = 1;
            if (left) size += left->size;
            if (right) size += right->size;
        }

        Pair pair() const
        {
            return {key, value};
        }

        Node* left;
        Node* right;
        Key key;
        Value value;
        std::size_t size;
        Color color;
    };

    /// Inserts every pair of the subtree rooted at @p node (pre-order).
    void _copy(const Node* node)
    {
        if (! node) return;
        insert(node->key, node->value);
        _copy(node->left);
        _copy(node->right);
    }

    /// Recursive LLRB insertion; returns the new root of this subtree.
    Node* _insert(Node* node, const Pair& pair)
    {
        if (! node) return new Node(pair);
        if (pair.key == node->key) node->value = pair.value;
        else if (pair.key < node->key)
        {
            node->left = _insert(node->left, pair);
        }
        else if (pair.key > node->key)
        {
            node->right = _insert(node->right, pair);
        }
        return _handle_color(node);
    }

    /// Top-down LLRB deletion (Sedgewick). Precondition: @p key is stored in
    /// the subtree rooted at @p node. Returns the new root of this subtree.
    Node* _erase(Node* node, const Key& key)
    {
        if (key < node->key)
        {
            // Make sure the left link we are about to follow is red (or
            // leads to a red node) so the deletion never empties a 2-node.
            if (! _is_red(node->left) && ! _is_red(node->left->left))
            {
                node = _move_red_left(node);
            }
            node->left = _erase(node->left, key);
        }
        else
        {
            if (_is_red(node->left)) node = _rotate_right(node);
            if (key == node->key && ! node->right)
            {
                // By the LLRB invariants this node is a leaf here; returning
                // its left link is purely defensive.
                Node* left = node->left;
                delete node;
                return left;
            }
            if (! _is_red(node->right) && ! _is_red(node->right->left))
            {
                node = _move_red_right(node);
            }
            if (key == node->key)
            {
                // Replace this node's contents with its in-order successor,
                // then delete the successor from the right subtree.
                const Node* successor = node->right;
                while (successor->left) successor = successor->left;
                node->key = successor->key;
                node->value = successor->value;
                node->right = _erase_min(node->right);
            }
            else
            {
                node->right = _erase(node->right, key);
            }
        }
        return _handle_color(node);
    }

    /// Deletes the minimum node of a non-empty subtree; returns its new root.
    Node* _erase_min(Node* node)
    {
        if (! node->left)
        {
            // Under the LLRB invariants the right link is null as well.
            Node* right = node->right;
            delete node;
            return right;
        }
        if (! _is_red(node->left) && ! _is_red(node->left->left))
        {
            node = _move_red_left(node);
        }
        node->left = _erase_min(node->left);
        return _handle_color(node);
    }

    /// Makes node->left or one of its children red (used while descending left).
    Node* _move_red_left(Node* node)
    {
        _flip_colors(node);
        if (_is_red(node->right->left))
        {
            node->right = _rotate_right(node->right);
            node = _rotate_left(node);
            _flip_colors(node);
        }
        return node;
    }

    /// Makes node->right or one of its children red (used while descending right).
    Node* _move_red_right(Node* node)
    {
        _flip_colors(node);
        if (_is_red(node->left->left))
        {
            node = _rotate_right(node);
            _flip_colors(node);
        }
        return node;
    }

    /// Deletes every node of the subtree rooted at @p node.
    void _clear(Node* node)
    {
        if (! node) return;
        _clear(node->left);
        _clear(node->right);
        delete node;
    }

    /// Returns the node holding @p key, or nullptr. Taking and returning a
    /// plain Node* lets both at() overloads share it without const errors.
    Node* _find(Node* node, const Key& key) const
    {
        while (node)
        {
            if (key < node->key) node = node->left;
            else if (key > node->key) node = node->right;
            else return node;
        }
        return nullptr;
    }

    /// Number of keys <= @p key in the subtree rooted at @p node.
    std::size_t _rank(const Node* node, const Key& key) const
    {
        if (! node) return 0;
        if (key < node->key) return _rank(node->left, key);
        std::size_t rank = 1;
        if (node->left) rank += node->left->size;
        if (key > node->key) rank += _rank(node->right, key);
        return rank;
    }

    /// Node with the given 1-based rank within the subtree, or nullptr.
    Node* _select(Node* node, std::size_t rank) const
    {
        while (node)
        {
            const std::size_t left_size = node->left ? node->left->size : 0;
            if (rank <= left_size) node = node->left;
            else if (rank == left_size + 1) return node;
            else
            {
                rank -= left_size + 1;
                node = node->right;
            }
        }
        return nullptr;
    }

    /// Appends, in order, the keys whose global 1-based rank is in
    /// [lower, upper]. @p offset is the number of keys that precede this
    /// subtree in the whole tree.
    void _keys(const Node* node,
               std::size_t lower,
               std::size_t upper,
               std::size_t offset,
               std::vector<Key>& keys) const
    {
        if (! node) return;
        const std::size_t rank =
            offset + 1 + (node->left ? node->left->size : 0);
        if (lower < rank) _keys(node->left, lower, upper, offset, keys);
        if (lower <= rank && rank <= upper) keys.push_back(node->key);
        if (upper > rank) _keys(node->right, lower, upper, rank, keys);
    }

    /// In-order traversal writing function(node) to @p itr for every node;
    /// returns the advanced iterator so the position carries across calls.
    template<typename DestinationIterator, typename Function>
    DestinationIterator _collect(const Node* node,
                                 DestinationIterator itr,
                                 const Function& function) const
    {
        if (! node) return itr;
        itr = _collect(node->left, itr, function);
        *itr = function(node);
        ++itr;
        return _collect(node->right, itr, function);
    }

    /// Node with the largest key <= @p key, or nullptr.
    const Node* _floor(const Node* node, const Key& key) const
    {
        if (! node) return nullptr;
        if (key == node->key) return node;
        if (key < node->key) return _floor(node->left, key);
        const Node* found = _floor(node->right, key);
        return found ? found : node;
    }

    /// Node with the smallest key >= @p key, or nullptr.
    const Node* _ceiling(const Node* node, const Key& key) const
    {
        if (! node) return nullptr;
        if (key == node->key) return node;
        if (key > node->key) return _ceiling(node->right, key);
        const Node* found = _ceiling(node->left, key);
        return found ? found : node;
    }

    /// Restores the LLRB shape locally (fix-up on the way back up) and the
    /// cached size; returns the new root of this subtree.
    Node* _handle_color(Node* node)
    {
        // Children are already correct; refresh this node before rotating,
        // because the rotations hand node->size to the new subtree root.
        node->resize();
        if (_is_red(node->right)) node = _rotate_left(node);
        if (_is_red(node->left) && _is_red(node->left->left))
        {
            node = _rotate_right(node);
        }
        if (_is_red(node->left) && _is_red(node->right)) _flip_colors(node);
        return node;
    }

    /// Turns a red right link into a left one; keeps sizes consistent.
    Node* _rotate_left(Node* node)
    {
        Node* right = node->right;
        assert(_is_red(right));
        node->right = right->left;
        right->left = node;
        right->color = node->color;
        node->color = Color::RED;
        right->size = node->size;
        node->resize();
        return right;
    }

    /// Turns a red left link into a right one; keeps sizes consistent.
    Node* _rotate_right(Node* node)
    {
        Node* left = node->left;
        assert(_is_red(left));
        node->left = left->right;
        left->right = node;
        left->color = node->color;
        node->color = Color::RED;
        left->size = node->size;
        node->resize();
        return left;
    }

    static Color _opposite(Color color)
    {
        return color == Color::RED ? Color::BLACK : Color::RED;
    }

    /// Toggles the colors of a node and both of its (non-null) children.
    /// Splits a 4-node on insertion and merges 2-nodes during deletion.
    void _flip_colors(Node* node)
    {
        node->color = _opposite(node->color);
        node->left->color = _opposite(node->left->color);
        node->right->color = _opposite(node->right->color);
    }

    /// Null links count as black.
    static bool _is_red(const Node* node)
    {
        return node && node->color == Color::RED;
    }

    Node* _root;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment