Skip to content

Instantly share code, notes, and snippets.

@goldsborough
Created November 5, 2015 00:56
Show Gist options
  • Save goldsborough/529689a484b58b913c94 to your computer and use it in GitHub Desktop.
Save goldsborough/529689a484b58b913c94 to your computer and use it in GitHub Desktop.
Red-black Tree <3
/// Ordered key/value map backed by a left-leaning red-black binary search
/// tree (Sedgewick's LLRB variant).
///
/// Invariants restored after every public mutation:
///   * red links lean left and no node has two red links in a row;
///   * every root-to-null path passes the same number of black nodes;
///   * the root is black;
///   * every node caches the number of nodes in its subtree (`size`).
/// These keep the height logarithmic, so lookups, insertions, erasures,
/// floor/ceiling and rank queries visit O(log n) nodes.
///
/// Key must be copy-constructible and comparable with both `<` and `>`.
/// Value must be default-constructible, copy-constructible and
/// copy-assignable.
template<typename Key, typename Value>
class RBTree
{
public:
    using size_t = std::size_t;

    /// Creates an empty tree.
    RBTree() noexcept
    : _root(nullptr)
    { }

    /// Builds a tree from {key, value} pairs. A repeated key keeps the value
    /// of its last occurrence.
    RBTree(const std::initializer_list<std::pair<Key, Value>>& list)
    : _root(nullptr)
    {
        try
        {
            for (const auto& item : list)
            {
                insert(item.first, item.second);
            }
        }
        catch (...)
        {
            // The destructor does not run for a partially constructed
            // object, so release whatever was inserted before rethrowing.
            _clear(_root);
            throw;
        }
    }

    /// Deep copy that preserves node colors and cached subtree sizes.
    RBTree(const RBTree& other)
    : _root(_copy(other._root))
    { }

    /// Steals `other`'s nodes; `other` is left empty.
    RBTree(RBTree&& other) noexcept
    : RBTree()
    {
        swap(other);
    }

    /// Copy-and-swap assignment (serves both copy and move assignment).
    RBTree& operator=(RBTree other)
    {
        swap(other);
        return *this;
    }

    void swap(RBTree& other) noexcept
    {
        // Enable ADL
        using std::swap;
        swap(_root, other._root);
    }

    friend void swap(RBTree& first, RBTree& second) noexcept
    {
        first.swap(second);
    }

    ~RBTree()
    {
        _clear(_root);
    }

    /// Inserts `key` with `value`, or overwrites the value if `key` exists.
    void insert(const Key& key, const Value& value)
    {
        _root = _insert(_root, key, value);
        _root->color = Color::Black;
    }

    /// Removes `key` from the tree.
    /// @throws std::invalid_argument if `key` is not present.
    void erase(const Key& key)
    {
        if (! contains(key))
        {
            throw std::invalid_argument("No such key!");
        }
        // Top-down deletion requires the current node or its left child to
        // be red; temporarily redden the root when neither child is red.
        if (! _is_red(_root->left) && ! _is_red(_root->right))
        {
            _root->color = Color::Red;
        }
        _root = _erase(_root, key);
        if (_root) _root->color = Color::Black;
    }

    /// Removes every element; the tree is empty and reusable afterwards.
    void clear()
    {
        _clear(_root);
        _root = nullptr;
    }

    /// @return the value stored for `key`.
    /// @throws std::invalid_argument if `key` is not present.
    Value& get(const Key& key)
    {
        auto node = _find(_root, key);
        if (! node)
        {
            throw std::invalid_argument("No such key!");
        }
        return node->value;
    }

    /// @return the value stored for `key`.
    /// @throws std::invalid_argument if `key` is not present.
    const Value& get(const Key& key) const
    {
        auto node = _find(_root, key);
        if (! node)
        {
            throw std::invalid_argument("No such key!");
        }
        return node->value;
    }

    /// @return true if `key` is stored in the tree.
    bool contains(const Key& key) const
    {
        return _find(_root, key) != nullptr;
    }

    /// Returns a reference to the value for `key`, first inserting a
    /// default-constructed Value when `key` is absent.
    Value& operator[](const Key& key)
    {
        auto node = _find(_root, key);
        if (! node)
        {
            node = new Node(key, Value());
            _root = _insert(_root, node);
            _root->color = Color::Black;
        }
        // Rebalancing relinks nodes but never reallocates them, so `node`
        // is still the node holding `key`.
        return node->value;
    }

    /// @return the largest stored key that is <= `key`.
    /// @throws std::runtime_error if every stored key is greater than `key`.
    const Key& floor(const Key& key) const
    {
        auto node = _floor(_root, key);
        if (! node)
        {
            throw std::runtime_error("No floor for given key!");
        }
        return node->key;
    }

    /// @return the smallest stored key that is >= `key`.
    /// @throws std::runtime_error if every stored key is smaller than `key`.
    const Key& ceiling(const Key& key) const
    {
        auto node = _ceiling(_root, key);
        if (! node)
        {
            throw std::runtime_error("No ceiling for given key!");
        }
        return node->key;
    }

    /// @return the number of stored keys that are <= `key`.
    size_t rank(const Key& key) const
    {
        return _rank(_root, key);
    }

    /// @return the number of stored keys in the closed range
    ///         [`first`, `second`]; 0 when `second < first`.
    size_t size(const Key& first, const Key& second) const
    {
        if (second < first) return 0;
        // rank(second) - rank(first) counts keys in (first, second];
        // add `first` itself back only when it is actually stored.
        return rank(second) - rank(first) + (contains(first) ? 1 : 0);
    }

    /// @return the number of stored keys.
    size_t size() const
    {
        return _size(_root);
    }

    bool is_empty() const
    {
        return size() == 0;
    }

private:
    enum class Color { Red, Black };

    struct Node
    {
        Node(const Key& key_,
             const Value& value_ = Value(),
             Node* left_ = nullptr,
             Node* right_ = nullptr,
             Color color_ = Color::Red)
        : key(key_)
        , value(value_)
        , left(left_)
        , right(right_)
        , color(color_)
        {
            resize();
        }

        /// Recomputes `size` from the (already correct) child sizes.
        void resize()
        {
            size = 1;
            if (left) size += left->size;
            if (right) size += right->size;
        }

        Key key;
        Value value;
        Node* left;
        Node* right;
        size_t size;   // number of nodes in the subtree rooted here
        Color color;   // color of the link from the parent to this node
    };

    /// Subtree size that treats a null subtree as empty.
    static size_t _size(const Node* node)
    {
        return node ? node->size : 0;
    }

    /// Inverts a node's color (used by color flips).
    static void _toggle(Node* node)
    {
        node->color = node->color == Color::Red ? Color::Black : Color::Red;
    }

    /// Inserts or overwrites `key` below `node`; returns the new subtree root.
    Node* _insert(Node* node, const Key& key, const Value& value)
    {
        if (! node) return new Node(key, value);
        if (key < node->key)
        {
            node->left = _insert(node->left, key, value);
        }
        else if (key > node->key)
        {
            node->right = _insert(node->right, key, value);
        }
        else node->value = value;
        return _handle_colors(node);
    }

    /// Links the already allocated `new_node` into the subtree rooted at
    /// `node`; returns the new subtree root. On an equal key, `new_node`
    /// takes over the old node's position, children and color.
    Node* _insert(Node* node, Node* new_node)
    {
        if (! node) return new_node;
        if (new_node->key < node->key)
        {
            node->left = _insert(node->left, new_node);
        }
        else if (new_node->key > node->key)
        {
            node->right = _insert(node->right, new_node);
        }
        else
        {
            // Keep the replaced node's subtree instead of dropping it.
            new_node->left = node->left;
            new_node->right = node->right;
            new_node->color = node->color;
            delete node;
            node = new_node;
        }
        return _handle_colors(node);
    }

    /// Removes `key` from the subtree rooted at `node` (LLRB top-down
    /// deletion); returns the new subtree root. Precondition: `key` is in
    /// the subtree, and `node` or `node->left` is red.
    Node* _erase(Node* node, const Key& key)
    {
        if (key < node->key)
        {
            // Carry a red link down the left spine before descending.
            if (! _is_red(node->left) && ! _is_red(node->left->left))
            {
                node = _move_red_left(node);
            }
            node->left = _erase(node->left, key);
        }
        else
        {
            // Here key >= node->key, and the rotations below only replace
            // `node` by a smaller key, so "not greater" means "equal".
            if (_is_red(node->left))
            {
                node = _rotate_right(node);
            }
            if (! (key > node->key) && ! node->right)
            {
                // With no right child and no red left child, the matching
                // node is a leaf in a valid LLRB.
                delete node;
                return nullptr;
            }
            if (! _is_red(node->right) && ! _is_red(node->right->left))
            {
                node = _move_red_right(node);
            }
            if (key > node->key)
            {
                node->right = _erase(node->right, key);
            }
            else
            {
                // Replace the node by its in-order successor, relinking the
                // successor node itself rather than copying its key/value.
                Node* successor = nullptr;
                node->right = _detach_min(node->right, successor);
                successor->left = node->left;
                successor->right = node->right;
                successor->color = node->color;
                delete node;
                node = successor;
            }
        }
        return _handle_colors(node);
    }

    /// Unlinks (without freeing) the minimum node of the subtree rooted at
    /// `node`, stores it in `min` and returns the new subtree root.
    Node* _detach_min(Node* node, Node*& min)
    {
        if (! node->left)
        {
            // In an LLRB a node without a left child has no right child.
            min = node;
            return nullptr;
        }
        if (! _is_red(node->left) && ! _is_red(node->left->left))
        {
            node = _move_red_left(node);
        }
        node->left = _detach_min(node->left, min);
        return _handle_colors(node);
    }

    Node* _find(Node* node, const Key& key) const
    {
        if (! node) return nullptr;
        if (key < node->key) return _find(node->left, key);
        else if (key > node->key) return _find(node->right, key);
        else return node;
    }

    /// Frees every node of the subtree rooted at `node`.
    void _clear(Node* node)
    {
        if (! node) return;
        _clear(node->left);
        _clear(node->right);
        delete node;
    }

    /// Deep-copies the subtree rooted at `other`, including colors.
    /// On exception, nothing allocated by this call is leaked.
    Node* _copy(const Node* other)
    {
        if (! other) return nullptr;
        Node* node = new Node(other->key, other->value,
                              nullptr, nullptr, other->color);
        try
        {
            node->left = _copy(other->left);
            node->right = _copy(other->right);
        }
        catch (...)
        {
            _clear(node);
            throw;
        }
        // Sizes can only be computed once the children are attached.
        node->resize();
        return node;
    }

    /// Smallest node whose key is >= `key`, or nullptr.
    Node* _ceiling(Node* node, const Key& key) const
    {
        Node* best = nullptr;
        while (node)
        {
            if (key < node->key)
            {
                best = node;
                node = node->left;
            }
            else if (key > node->key)
            {
                node = node->right;
            }
            else return node;
        }
        return best;
    }

    /// Largest node whose key is <= `key`, or nullptr.
    Node* _floor(Node* node, const Key& key) const
    {
        Node* best = nullptr;
        while (node)
        {
            if (key > node->key)
            {
                best = node;
                node = node->right;
            }
            else if (key < node->key)
            {
                node = node->left;
            }
            else return node;
        }
        return best;
    }

    /// Number of keys <= `key` in the subtree rooted at `node`, using the
    /// cached subtree sizes instead of visiting every smaller node.
    size_t _rank(Node* node, const Key& key) const
    {
        size_t rank = 0;
        while (node)
        {
            if (key < node->key)
            {
                node = node->left;
            }
            else
            {
                // node and its whole left subtree are <= key.
                rank += 1 + _size(node->left);
                node = node->right;
            }
        }
        return rank;
    }

    /// Turns a red right link into a red left link; returns the new root.
    Node* _rotate_left(Node* node)
    {
        assert(_is_red(node->right));
        auto right = node->right;
        node->right = right->left;
        right->left = node;
        right->color = node->color;
        node->color = Color::Red;
        node->resize();
        right->resize();
        return right;
    }

    /// Turns a red left link into a red right link; returns the new root.
    Node* _rotate_right(Node* node)
    {
        assert(_is_red(node->left));
        auto left = node->left;
        node->left = left->right;
        left->right = node;
        left->color = node->color;
        node->color = Color::Red;
        node->resize();
        left->resize();
        return left;
    }

    /// Inverts the colors of `node` and both children: splits a temporary
    /// 4-node during insertion, or merges siblings during deletion.
    Node* _flip_colors(Node* node)
    {
        assert(node->left && node->right);
        _toggle(node);
        _toggle(node->left);
        _toggle(node->right);
        return node;
    }

    /// Deletion helper: make `node->left` or one of its children red.
    Node* _move_red_left(Node* node)
    {
        _flip_colors(node);
        if (_is_red(node->right->left))
        {
            node->right = _rotate_right(node->right);
            node = _rotate_left(node);
            _flip_colors(node);
        }
        return node;
    }

    /// Deletion helper: make `node->right` or one of its children red.
    Node* _move_red_right(Node* node)
    {
        _flip_colors(node);
        if (_is_red(node->left->left))
        {
            node = _rotate_right(node);
            _flip_colors(node);
        }
        return node;
    }

    /// Restores the local LLRB shape on the way back up and refreshes the
    /// cached size; returns the (possibly new) subtree root.
    Node* _handle_colors(Node* node)
    {
        if (_is_red(node->right))
        {
            node = _rotate_left(node);
        }
        if (_is_red(node->left) && _is_red(node->left->left))
        {
            node = _rotate_right(node);
        }
        if (_is_red(node->left) && _is_red(node->right))
        {
            node = _flip_colors(node);
        }
        node->resize();
        return node;
    }

    /// Null links count as black.
    bool _is_red(const Node* node) const
    {
        return node && node->color == Color::Red;
    }

    Node* _root;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment