Skip to content

Instantly share code, notes, and snippets.

@goldsborough
Created November 11, 2015 19:16
Show Gist options
  • Save goldsborough/2d0a601a66826c780fc8 to your computer and use it in GitHub Desktop.
Save goldsborough/2d0a601a66826c780fc8 to your computer and use it in GitHub Desktop.
A Red Black Tree.
#ifndef RED_BLACK_TREE_HPP
#define RED_BLACK_TREE_HPP
#include <assert.h>
#include <cstddef>
#include <initializer_list>
#include <stdexcept>
#include <utility>
/**
 * An ordered key-value map implemented as a left-leaning red-black (LLRB)
 * tree, the 2-3 tree variant described by Sedgewick.
 *
 * Every mutating operation restores these invariants before returning:
 *   - red nodes are only ever left children (red links lean left),
 *   - no red node has a red child,
 *   - every path from the root to a null link crosses the same number of
 *     black nodes,
 *   - the root is black.
 * Together they keep the height logarithmic, so get/insert/erase are
 * O(log n).
 *
 * Keys are ordered with operator< only; two keys a and b are treated as
 * equal when !(a < b) && !(b < a). erase() additionally requires Key and
 * Value to be move-assignable (the in-order successor's payload is moved
 * into the node being removed).
 */
template<typename Key, typename Value>
class RedBlackTree
{
public:
    using size_t = std::size_t;

    /// Creates an empty tree.
    RedBlackTree() noexcept
    : _root(nullptr)
    , _size(0)
    { }

    /// Builds a tree from (key, value) pairs; later duplicates overwrite
    /// earlier values for the same key.
    RedBlackTree(const std::initializer_list<std::pair<Key, Value>>& list)
    : RedBlackTree()
    {
        for (const auto& item : list)
        {
            insert(item.first, item.second);
        }
    }

    /// Deep copy: reproduces other's shape, colors and size exactly.
    RedBlackTree(const RedBlackTree& other)
    : _root(_copy(other._root))
    , _size(other._size)
    { }

    /// Steals other's nodes; other is left empty.
    RedBlackTree(RedBlackTree&& other) noexcept
    : RedBlackTree()
    {
        swap(other);
    }

    /// Copy-and-swap assignment (handles both copy and move assignment).
    RedBlackTree& operator=(RedBlackTree other) noexcept
    {
        swap(other);
        return *this;
    }

    /// Exchanges the contents of two trees in O(1).
    void swap(RedBlackTree& other) noexcept
    {
        using std::swap;
        swap(_root, other._root);
        swap(_size, other._size);
    }

    friend void swap(RedBlackTree& first,
                     RedBlackTree& second) noexcept
    {
        first.swap(second);
    }

    ~RedBlackTree()
    {
        _clear(_root);
    }

    /// Inserts key with value, or overwrites the value if key is present.
    void insert(const Key& key, const Value& value)
    {
        Node* slot = nullptr;
        _root = _insert(_root, key, value, slot);
        _root->color = Color::Black;  // LLRB invariant: the root is black
    }

    /// @return the value stored for key.
    /// @throws std::invalid_argument if key is not present.
    Value& get(const Key& key)
    {
        return _find_or_throw(_root, key)->value;
    }

    /// @return the value stored for key.
    /// @throws std::invalid_argument if key is not present.
    const Value& get(const Key& key) const
    {
        return _find_or_throw(_root, key)->value;
    }

    /// @return the value stored for key, first inserting a value-initialized
    /// Value if key is absent.
    Value& operator[](const Key& key)
    {
        if (Node* node = _find(_root, key)) return node->value;

        Node* slot = nullptr;
        _root = _insert(_root, key, Value(), slot);
        _root->color = Color::Black;
        return slot->value;
    }

    /// @return true if key is present in the tree.
    bool contains(const Key& key) const
    {
        return _find(_root, key) != nullptr;
    }

    /// Removes key and its value.
    /// @throws std::invalid_argument if key is not present (tree unchanged).
    void erase(const Key& key)
    {
        if (! contains(key))
        {
            throw std::invalid_argument("No such key!");
        }

        // Top-down deletion borrows a red link from the parent on the way
        // down; when the root is a 2-node, temporarily make it red so the
        // first step has something to borrow.
        if (! _is_red(_root->left) && ! _is_red(_root->right))
        {
            _root->color = Color::Red;
        }

        _root = _erase(_root, key);

        if (_root) _root->color = Color::Black;
    }

    /// Removes every element; the tree is reusable afterwards.
    void clear()
    {
        _clear(_root);
        _root = nullptr;  // never leave a dangling root behind
        _size = 0;
    }

    /// @return the number of stored keys.
    size_t size() const noexcept
    {
        return _size;
    }

    /// @return true if the tree holds no keys.
    bool is_empty() const noexcept
    {
        return _size == 0;
    }

private:
    enum class Color { Red, Black };

    struct Node
    {
        Node(const Key& key_,
             const Value& value_ = Value(),
             Node* left_ = nullptr,
             Node* right_ = nullptr,
             Color color_ = Color::Red)
        : key(key_)
        , value(value_)
        , left(left_)
        , right(right_)
        , color(color_)
        { }

        Key key;
        Value value;
        Node* left;
        Node* right;
        Color color;
    };

    /// Key equivalence derived from operator< alone.
    static bool _keys_equal(const Key& a, const Key& b)
    {
        return ! (a < b) && ! (b < a);
    }

    /// Null links count as black.
    static bool _is_red(const Node* node) noexcept
    {
        return node && node->color == Color::Red;
    }

    /// Restores the LLRB shape at node after a child subtree changed.
    static Node* _handle_colors(Node* node)
    {
        // A lone red right link: make it lean left.
        if (_is_red(node->right) && ! _is_red(node->left))
        {
            node = _rotate_left(node);
        }
        // Two reds in a row on the left: rotate to balance them.
        if (_is_red(node->left) && _is_red(node->left->left))
        {
            node = _rotate_right(node);
        }
        // A temporary 4-node: split it by passing the red link upward.
        if (_is_red(node->left) && _is_red(node->right))
        {
            _flip_colors(node);
        }
        return node;
    }

    /// Turns a red right link into a red left link; returns the new top.
    static Node* _rotate_left(Node* node)
    {
        assert(node && _is_red(node->right));
        Node* right = node->right;
        node->right = right->left;
        right->left = node;
        right->color = node->color;
        node->color = Color::Red;
        return right;
    }

    /// Turns a red left link into a red right link; returns the new top.
    static Node* _rotate_right(Node* node)
    {
        assert(node && _is_red(node->left));
        Node* left = node->left;
        node->left = left->right;
        left->right = node;
        left->color = node->color;
        node->color = Color::Red;
        return left;
    }

    static void _toggle_color(Node* node) noexcept
    {
        node->color = (node->color == Color::Red) ? Color::Black : Color::Red;
    }

    /// Inverts the colors of node and both children. Insertion uses it to
    /// split a 4-node; deletion uses it to merge siblings into a 4-node.
    static void _flip_colors(Node* node)
    {
        assert(node && node->left && node->right);
        _toggle_color(node);
        _toggle_color(node->left);
        _toggle_color(node->right);
    }

    /// Ensures node->left or one of its children is red before descending
    /// left during deletion. Assumes node is red and both children black.
    static Node* _move_red_left(Node* node)
    {
        _flip_colors(node);
        if (_is_red(node->right->left))
        {
            // Borrow from the right sibling instead of leaving a 4-node.
            node->right = _rotate_right(node->right);
            node = _rotate_left(node);
            _flip_colors(node);
        }
        return node;
    }

    /// Ensures node->right or one of its children is red before descending
    /// right during deletion. Assumes node is red and both children black.
    static Node* _move_red_right(Node* node)
    {
        _flip_colors(node);
        if (_is_red(node->left->left))
        {
            node = _rotate_right(node);
            _flip_colors(node);
        }
        return node;
    }

    /// Inserts or overwrites key below node; returns the new subtree root.
    /// On return, slot points at the node that holds key.
    Node* _insert(Node* node, const Key& key, const Value& value, Node*& slot)
    {
        if (! node)
        {
            slot = new Node(key, value);
            ++_size;  // only after allocation succeeded
            return slot;
        }

        if (key < node->key)
        {
            node->left = _insert(node->left, key, value, slot);
        }
        else if (node->key < key)
        {
            node->right = _insert(node->right, key, value, slot);
        }
        else
        {
            node->value = value;
            slot = node;
        }

        return _handle_colors(node);
    }

    /// Deletes key from the subtree rooted at node (Sedgewick's top-down
    /// LLRB deletion). Precondition: key is present in this subtree.
    Node* _erase(Node* node, const Key& key)
    {
        if (key < node->key)
        {
            // key lies in the (non-empty) left subtree; never descend into
            // a 2-node, otherwise removing a leaf would unbalance it.
            if (! _is_red(node->left) && ! _is_red(node->left->left))
            {
                node = _move_red_left(node);
            }
            node->left = _erase(node->left, key);
        }
        else
        {
            if (_is_red(node->left))
            {
                node = _rotate_right(node);
            }

            // A matching node without a right child is a leaf here
            // (LLRB shape rules out a lone black left child).
            if (_keys_equal(key, node->key) && ! node->right)
            {
                assert(! node->left);
                delete node;
                --_size;
                return nullptr;
            }

            if (! _is_red(node->right) && ! _is_red(node->right->left))
            {
                node = _move_red_right(node);
            }

            if (_keys_equal(key, node->key))
            {
                // Replace this node's payload with its in-order successor,
                // then physically delete the successor from the right subtree.
                Node* successor = _min(node->right);
                node->key = std::move(successor->key);
                node->value = std::move(successor->value);
                node->right = _erase_min(node->right);
            }
            else
            {
                node->right = _erase(node->right, key);
            }
        }

        return _handle_colors(node);
    }

    /// Deletes the smallest node of a non-empty subtree; returns its new root.
    /// Follows left links only, so it never compares keys.
    Node* _erase_min(Node* node)
    {
        if (! node->left)
        {
            // Left-leaning shape: no left child implies no right child.
            assert(! node->right);
            delete node;
            --_size;
            return nullptr;
        }

        if (! _is_red(node->left) && ! _is_red(node->left->left))
        {
            node = _move_red_left(node);
        }
        node->left = _erase_min(node->left);

        return _handle_colors(node);
    }

    /// @return the leftmost node of a non-empty subtree.
    static Node* _min(Node* node)
    {
        while (node->left) node = node->left;
        return node;
    }

    /// @return the node holding key, or nullptr if absent.
    static Node* _find(Node* node, const Key& key)
    {
        while (node)
        {
            if (key < node->key)      node = node->left;
            else if (node->key < key) node = node->right;
            else                      return node;
        }
        return nullptr;
    }

    /// Like _find, but throws std::invalid_argument when key is absent.
    static Node* _find_or_throw(Node* node, const Key& key)
    {
        Node* found = _find(node, key);
        if (! found)
        {
            throw std::invalid_argument("No such key!");
        }
        return found;
    }

    /// Deep-copies a subtree, preserving colors. If any allocation or
    /// element copy throws, the partial copy is freed before rethrowing.
    static Node* _copy(const Node* other_node)
    {
        if (! other_node) return nullptr;

        Node* node = new Node(other_node->key, other_node->value,
                              nullptr, nullptr, other_node->color);
        try
        {
            node->left = _copy(other_node->left);
            node->right = _copy(other_node->right);
        }
        catch (...)
        {
            _clear(node);
            throw;
        }
        return node;
    }

    /// Frees every node of a subtree (post-order).
    static void _clear(Node* node)
    {
        if (! node) return;
        _clear(node->left);
        _clear(node->right);
        delete node;
    }

    Node* _root;
    size_t _size;
};
#endif /* RED_BLACK_TREE_HPP */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment