Skip to content

Instantly share code, notes, and snippets.

@goldsborough
Created October 5, 2015 16:28
Show Gist options
  • Save goldsborough/29dc6a8234a02d614c85 to your computer and use it in GitHub Desktop.
Save goldsborough/29dc6a8234a02d614c85 to your computer and use it in GitHub Desktop.
Optimized, colorless Red-Black tree which stores color information in relative node-ordering
template<typename Key, typename Value>
class RedBlackTree
{
public:
using size_t = std::size_t;
RedBlackTree()
: _root(nullptr)
, _size(0)
{ }
RedBlackTree(std::initializer_list<std::pair<Key, Value>> items)
: RedBlackTree()
{
for (const auto& pair : items)
{
insert(pair.first, pair.second);
}
}
RedBlackTree(const RedBlackTree& other)
: _root(nullptr)
, _size(other._size)
{
_root = _copy(other._root);
}
RedBlackTree(RedBlackTree&& other)
: RedBlackTree()
{
swap(other);
}
RedBlackTree& operator=(RedBlackTree other)
{
swap(other);
return *this;
}
void swap(RedBlackTree& other)
{
// Enable ADL
using std::swap;
swap(_root, other._root);
swap(_size, other._size);
}
friend void swap(RedBlackTree& first, RedBlackTree& second)
{
first.swap(second);
}
~RedBlackTree()
{
_clear(_root);
}
void insert(const Key& key, const Value& value)
{
_root = _insert(_root, key, value);
}
Value& at(const Key& key)
{
Node* const& node = _find(_root, key);
if (! node) throw std::invalid_argument("Key not found!");
return node->value;
}
const Value& at(const Key& key) const
{
const Node* const& node = _find(_root, key);
if (! node) throw std::invalid_argument("Key not found!");
return node->value;
}
bool contains(const Key& key) const
{
return _find(_root, key) != nullptr;
}
void erase(const Key& key)
{
_root = _erase(_root, key);
--_size;
}
void clear()
{
_clear(_root);
_root = nullptr;
_size = 0;
}
Value& operator[](const Key& key)
{
Node* node = _find(_root, key);
if (! node)
{
node = new Node(key);
_root = _insert(_root, node);
}
return node->value;
}
size_t size() const
{
return _size;
}
const Key& floor(const Key& key) const
{
return _floor(_root, key);
}
const Key& ceiling(const Key& key) const
{
return _ceiling(_root, key);
}
size_t rank(const Key& key) const
{
return _rank(_root, key);
}
size_t range_count(const Key& lower, const Key& upper) const
{
return rank(upper) - rank(lower) + 1;
}
template<typename InputIterator>
InputIterator range_search(const Key& lower,
const Key& upper,
InputIterator destination) const
{
_range_search(_root, lower, upper, destination);
return destination;
}
bool is_empty() const
{
return _size == 0;
}
private:
struct Node
{
Node(const Key& key_ = Key(), const Value& value_ = Value())
: key(key_)
, value(value_)
, left(nullptr)
, right(nullptr)
, size(1)
{ }
Node(const Node& other)
: key(other.key)
, value(other.value)
, size(other.size)
, left(nullptr)
, right(nullptr)
{ }
void resize()
{
size = 1;
if (left) size += left->size;
if (right) size += right->size;
}
Key key;
Value value;
Node* left;
Node* right;
size_t size;
};
using ordered_t = std::pair<Node*&, Node*&>;
Node* _insert(Node* node, const Key& key, const Value& value)
{
if (! node)
{
++_size;
return new Node(key, value);
}
auto ordered = _ordered(node);
if (key < node->key)
{
ordered.first = _insert(ordered.first, key, value);
}
else if (key > node->key)
{
ordered.second = _insert(ordered.second, key, value);
}
else node->value = value;
node->resize();
return _handle_colors(node);
}
Node* _insert(Node* node, Node* new_node)
{
if (! node)
{
++_size;
node = new_node;
return node;
}
auto ordered = _ordered(node);
if (new_node->key < node->key)
{
ordered.first = _insert(ordered.first, new_node);
}
else if (new_node->key > node->key)
{
ordered.second = _insert(ordered.second, new_node);
}
else
{
node->value = new_node->value;
delete new_node;
}
node->resize();
return _handle_colors(node);
}
Node* const& _find(Node* const& node, const Key& key) const
{
if (! node) return node;
auto ordered = _ordered(node);
if (key < node->key) return _find(ordered.first, key);
else if (key > node->key) return _find(ordered.second, key);
else return node;
}
Node* _erase(Node* node, const Key& key)
{
if (! node) throw std::invalid_argument("Key not found!");
auto ordered = _ordered(node);
if (key < node->key)
{
ordered.first = _erase(ordered.first, key);
}
else if (key > node->key)
{
ordered.second = _erase(ordered.second, key);
}
else
{
if (! node->left)
{
delete node;
return node->right;
}
else if (! node->right)
{
delete node;
return node->left;
}
Node* previous = nullptr;
Node* successor = node->right;
while (successor->left)
{
previous = successor;
previous->size--;
successor = successor->left;
}
if (previous) previous->left = successor->right;
else node->right = successor->right;
_swap(node, successor);
delete node;
node = successor;
}
node->resize();
return _handle_colors(node);
}
void _clear(Node* node)
{
if (! node) return;
_clear(node->left);
_clear(node->right);
delete node;
}
const Node* _floor(const Node* node, const Key& key) const
{
if (! node) return nullptr;
auto ordered = _ordered(node);
if (key > node->key)
{
Node* result = _floor(node->second, key);
return result ? result : node;
}
else return _floor(ordered.first, key);
}
const Node* _ceiling(const Node* node, const Key& key) const
{
if (! node) return nullptr;
auto ordered = _ordered(node);
if (key < node->key)
{
Node* result = _ceiling(ordered.first, key);
return result ? result : node;
}
else return _ceiling(ordered.second, key);
}
size_t _rank(const Node* node, const Key& key) const
{
if (! node) return 0;
auto ordered = _ordered(node);
if (key < node->key) return _rank(ordered.first, key);
size_t size = 1;
if (ordered.first) size += ordered.first.size;
if (key > node->key) size += _rank(ordered.second, key);
return size;
}
template<typename InputIterator>
void _range_search(const Node* node,
const Key& lower,
const Key& upper,
InputIterator destination) const
{
if (! node) return;
auto ordered = _ordered(node);
if (node->key > lower)
{
_range_search(ordered.first, lower, upper, destination);
}
if (node->key >= lower && node->key <= upper)
{
*destination++ = node->key;
}
if (node->key < upper)
{
_range_search(ordered.second, lower, upper, destination);
}
}
Node* _copy(Node* other)
{
if (! other) return nullptr;
Node* node = new Node(*other);
node->left = _copy(other->left);
node->right = _copy(other->right);
return node;
}
Node* _rotate_left(Node* node)
{
assert(_is_red(node->right));
Node* right = node->right;
_swap(right);
node->right = right->left;
right->left = node;
_swap(node);
node->resize();
return right;
}
Node* _rotate_right(Node* node)
{
assert(_is_red(node->left));
assert(_is_red(node->left->left));
_swap(node);
_swap(node->right);
return node;
}
Node* _flip_colors(Node* node)
{
assert(_is_red(node->left));
assert(_is_red(node->right));
// Swap the red nodes' points to make them black
// they're red before, i.e. their left nodes
// contain the greater value and their right
// nodes the smaller one, to change color just swap
_swap(node->left);
_swap(node->right);
// Then also swap the nodes of the to-be red node
_swap(node);
return node;
}
Node* _handle_colors(Node* node)
{
if (_is_red(node->right))
{
node = _rotate_left(node);
}
if (_is_red(node->left) && _is_red(node->left->left))
{
node = _rotate_right(node);
}
if (_is_red(node->left) && _is_red(node->right))
{
node = _flip_colors(node);
}
return node;
}
bool _is_red(Node* node) const
{
if (! node) return false;
else if (node->left) return node->left->value > node->value;
else if (node->right) return node->right->value < node->value;
else return true;
}
ordered_t _ordered(Node* node) const
{
if (_is_red(node)) return {node->right, node->left};
return {node->left, node->right};
}
void _swap(Node* node)
{
_swap(node->left, node->right);
}
void _swap(Node* first, Node* second)
{
Node* temp = first->left;
first->left= second->left;
second->left = temp;
temp = first->right;
first->right = second->right;
second->right = temp;
}
Node* _root;
size_t _size;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment