Created October 27, 2015 00:32
-
-
Save goldsborough/39194d387d433ad78d9b to your computer and use it in GitHub Desktop.
A red-black tree in C++.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Ordered map implemented as a left-leaning red-black binary search tree.
///
/// Keys are compared with `operator<` and `operator>`. Two keys are treated as
/// equal when neither is less than nor greater than the other. Every node
/// caches the number of nodes in its subtree, so size() is O(1) and rank()
/// follows a single root-to-leaf path.
///
/// Nodes are never reallocated and never swap their key/value with another
/// node. References returned by get(), operator[], floor() and ceiling() stay
/// valid until that entry is erased or the tree is cleared, assigned or
/// destroyed.
template<typename Key, typename Value>
class RedBlackTree
{
public:
    using size_t = std::size_t;

    /// Creates an empty tree.
    RedBlackTree()
    : _root(nullptr)
    { }

    /// Creates a tree holding the given pairs, inserted in order.
    /// For a repeated key, the last value wins.
    RedBlackTree(std::initializer_list<std::pair<Key, Value>> list)
    : _root(nullptr)
    {
        try
        {
            for (const auto& item : list)
            {
                insert(item.first, item.second);
            }
        }
        catch (...)
        {
            // The destructor does not run when a constructor throws, so free
            // whatever was inserted before the failure.
            clear();
            throw;
        }
    }

    /// Deep-copies `other`, including node colors and cached subtree sizes.
    RedBlackTree(const RedBlackTree& other)
    : _root(_copy(other._root))
    { }

    /// Takes over the nodes of `other`, leaving it empty.
    RedBlackTree(RedBlackTree&& other) noexcept
    : RedBlackTree()
    {
        swap(other);
    }

    /// Copy-and-swap assignment.
    /// `other` is a fresh copy (or a moved-in tree). Our old nodes are
    /// released when `other` goes out of scope.
    RedBlackTree& operator=(RedBlackTree other) noexcept
    {
        swap(other);
        return *this;
    }

    /// Exchanges the contents of two trees in O(1).
    void swap(RedBlackTree& other) noexcept
    {
        // Enable Argument-Dependent-Lookup (ADL)
        using std::swap;
        swap(_root, other._root);
    }

    friend void swap(RedBlackTree& first, RedBlackTree& second) noexcept
    {
        first.swap(second);
    }

    ~RedBlackTree()
    {
        clear();
    }

    /// Inserts `key` -> `value`, overwriting the value if `key` already exists.
    void insert(const Key& key, const Value& value)
    {
        _root = _insert(_root, key, value);
        // A color flip at the top can paint the root red; keep the root black.
        _root->color = Color::Black;
    }

    /// Returns true if `key` is present.
    bool contains(const Key& key) const
    {
        return _find(_root, key) != nullptr;
    }

    /// Returns the value stored under `key`.
    /// @throws std::invalid_argument if `key` is absent.
    const Value& get(const Key& key) const
    {
        return _get(key);
    }

    /// Returns a mutable reference to the value stored under `key`.
    /// @throws std::invalid_argument if `key` is absent.
    Value& get(const Key& key)
    {
        return _get(key);
    }

    /// Removes `key` and its value.
    /// References to all other entries stay valid.
    /// @throws std::invalid_argument if `key` is absent (the tree is unchanged).
    void erase(const Key& key)
    {
        if (! contains(key))
        {
            throw std::invalid_argument("No such key to erase!");
        }
        // Top-down deletion needs the current node or its left child to be
        // red. If the root is a lone black node, temporarily paint it red.
        if (! _is_red(_root->left) && ! _is_red(_root->right))
        {
            _root->color = Color::Red;
        }
        _root = _erase(_root, key);
        if (_root)
        {
            _root->color = Color::Black;
        }
    }

    /// Destroys every node, leaving an empty, reusable tree.
    void clear()
    {
        _clear(_root);
        // Without this the tree would keep pointing at freed memory.
        _root = nullptr;
    }

    /// Returns the value under `key`.
    /// If `key` is absent, a default-constructed Value is inserted first.
    Value& operator[](const Key& key)
    {
        Node* found = _find(_root, key);
        if (! found)
        {
            found = new Node(key);
            _root = _insert(_root, found);
            _root->color = Color::Black;
        }
        return found->value;
    }

    /// Returns how many keys are less than or equal to `key`.
    /// A present key therefore has a 1-based rank.
    size_t rank(const Key& key) const
    {
        return _rank(_root, key);
    }

    /// Returns the largest key that is less than or equal to `key`.
    /// @throws std::invalid_argument if every key is greater than `key`.
    const Key& floor(const Key& key) const
    {
        auto node = _floor(_root, key);
        if (! node)
        {
            throw std::invalid_argument("No floor for this key!");
        }
        return node->key;
    }

    /// Returns the smallest key that is greater than or equal to `key`.
    /// @throws std::invalid_argument if every key is less than `key`.
    const Key& ceiling(const Key& key) const
    {
        auto node = _ceiling(_root, key);
        if (! node)
        {
            throw std::invalid_argument("No ceiling for this key!");
        }
        return node->key;
    }

    /// Number of stored keys.
    size_t size() const
    {
        return _root ? _root->size : 0;
    }

    bool is_empty() const
    {
        return _root == nullptr;
    }

private:
    enum class Color { Red, Black };

    struct Node
    {
        Node(const Key& key_ = Key(),
             const Value& value_ = Value(),
             Node* left_ = nullptr,
             Node* right_ = nullptr,
             Color color_ = Color::Red)
        : key(key_)
        , value(value_)
        , left(left_)
        , right(right_)
        , size(1)
        , color(color_)
        { }

        /// Recomputes `size` from the (already correct) child sizes.
        void resize()
        {
            size = 1;
            if (left) size += left->size;
            if (right) size += right->size;
        }

        Key key;
        Value value;
        Node* left;
        Node* right;
        size_t size;   // number of nodes in the subtree rooted here
        Color color;   // color of the link from the parent to this node
    };

    /// Shared implementation of both get() overloads.
    Value& _get(const Key& key) const
    {
        auto found = _find(_root, key);
        if (! found)
        {
            throw std::invalid_argument("No such key!");
        }
        return found->value;
    }

    /// True when neither key orders before the other (the same test _find uses).
    static bool _equals(const Key& lhs, const Key& rhs)
    {
        return ! (lhs < rhs) && ! (lhs > rhs);
    }

    /// Inserts or overwrites `key` in the subtree at `node`.
    /// Returns the rebalanced subtree root.
    Node* _insert(Node* node, const Key& key, const Value& value)
    {
        if (! node) return new Node(key, value);
        if (key < node->key)
        {
            node->left = _insert(node->left, key, value);
        }
        else if (key > node->key)
        {
            node->right = _insert(node->right, key, value);
        }
        else node->value = value;
        return _handle_colors(node);
    }

    /// Links the freshly allocated `new_node` into the subtree at `node`.
    /// Precondition: its key is not present yet. operator[] checks this first;
    /// an equal key would otherwise drop and leak `new_node`.
    Node* _insert(Node* node, Node* new_node)
    {
        if (! node) return new_node;
        if (new_node->key < node->key)
        {
            node->left = _insert(node->left, new_node);
        }
        else if (new_node->key > node->key)
        {
            node->right = _insert(node->right, new_node);
        }
        return _handle_colors(node);
    }

    Node* _find(Node* node, const Key& key) const
    {
        if (! node) return node;
        if (key < node->key) return _find(node->left, key);
        else if (key > node->key) return _find(node->right, key);
        else return node;
    }

    /// Removes `key` from the subtree at `node` and returns the new subtree root.
    ///
    /// Precondition: `key` is present in this subtree, and `node` or
    /// `node->left` is red. erase() establishes this at the root; the
    /// _move_red_* helpers maintain it on the way down.
    ///
    /// `key` may alias the stored key of the node being removed, so it is never
    /// read after that node is freed.
    Node* _erase(Node* node, const Key& key)
    {
        if (key < node->key)
        {
            // Going left: make sure the left child is not a lone black node.
            if (! _is_red(node->left) && ! _is_red(node->left->left))
            {
                node = _move_red_left(node);
            }
            node->left = _erase(node->left, key);
        }
        else
        {
            // Lean a red left link to the right so there is a red link to
            // carry down the right side.
            if (_is_red(node->left))
            {
                node = _rotate_right(node);
            }
            // Target at the bottom: no right child. By black balance it has
            // no black left child either.
            if (_equals(key, node->key) && ! node->right)
            {
                Node* remaining = node->left;
                delete node;
                return remaining;
            }
            if (! _is_red(node->right) && ! _is_red(node->right->left))
            {
                node = _move_red_right(node);
            }
            if (_equals(key, node->key))
            {
                // Replace the target with its in-order successor node.
                // Relinking instead of copying the key/value keeps references
                // to the successor's entry valid.
                Node* successor = nullptr;
                Node* right = _detach_min(node->right, successor);
                successor->left = node->left;
                successor->right = right;
                successor->color = node->color;
                delete node;
                node = successor;
            }
            else
            {
                node->right = _erase(node->right, key);
            }
        }
        return _handle_colors(node);
    }

    /// Unlinks the minimum node of the subtree at `node` without freeing it.
    /// Stores that node in `min` and returns the rebalanced subtree root.
    Node* _detach_min(Node* node, Node*& min)
    {
        if (! node->left)
        {
            min = node;
            // Under the LLRB invariants this is nullptr; hand back whatever is
            // there so nothing can be lost.
            return node->right;
        }
        if (! _is_red(node->left) && ! _is_red(node->left->left))
        {
            node = _move_red_left(node);
        }
        node->left = _detach_min(node->left, min);
        return _handle_colors(node);
    }

    /// Counts the keys <= `key` in the subtree at `node`.
    size_t _rank(Node* node, const Key& key) const
    {
        if (! node) return 0;
        if (key < node->key) return _rank(node->left, key);
        // `node` itself and its whole left subtree are <= key.
        size_t count = 1;
        if (node->left) count += node->left->size;
        if (key > node->key) count += _rank(node->right, key);
        return count;
    }

    /// Largest node with key <= `key`, or nullptr.
    Node* _floor(Node* node, const Key& key) const
    {
        if (! node) return nullptr;
        if (key < node->key) return _floor(node->left, key);
        if (key > node->key)
        {
            auto found = _floor(node->right, key);
            return found ? found : node;
        }
        return node; // an exact match is its own floor
    }

    /// Smallest node with key >= `key`, or nullptr.
    Node* _ceiling(Node* node, const Key& key) const
    {
        if (! node) return nullptr;
        if (key > node->key) return _ceiling(node->right, key);
        if (key < node->key)
        {
            auto found = _ceiling(node->left, key);
            return found ? found : node;
        }
        return node; // an exact match is its own ceiling
    }

    /// Frees every node of the subtree at `node` (post-order).
    static void _clear(Node* node)
    {
        if (! node) return;
        _clear(node->left);
        _clear(node->right);
        delete node;
    }

    /// Deep-copies the subtree at `other`, preserving colors and sizes.
    /// If a copy throws, the partially built subtree is freed before rethrowing.
    static Node* _copy(const Node* other)
    {
        if (! other) return nullptr;
        Node* node = new Node(other->key, other->value, nullptr, nullptr, other->color);
        try
        {
            node->left = _copy(other->left);
            node->right = _copy(other->right);
        }
        catch (...)
        {
            _clear(node);
            throw;
        }
        node->size = other->size;
        return node;
    }

    /// Turns a red right link into a red left link. Returns the new subtree root.
    Node* _rotate_left(Node* node)
    {
        assert(_is_red(node->right));
        Node* right = node->right;
        node->right = right->left;
        right->left = node;
        right->color = node->color;
        node->color = Color::Red;
        // Resize bottom-up: `node` is now a child of `right`.
        node->resize();
        right->resize();
        return right;
    }

    /// Turns a red left link into a red right link. Returns the new subtree root.
    Node* _rotate_right(Node* node)
    {
        assert(_is_red(node->left));
        Node* left = node->left;
        node->left = left->right;
        left->right = node;
        left->color = node->color;
        node->color = Color::Red;
        // Resize bottom-up: `node` is now a child of `left`.
        node->resize();
        left->resize();
        return left;
    }

    static Color _opposite(Color color)
    {
        return color == Color::Red ? Color::Black : Color::Red;
    }

    /// Toggles the colors of `node` and both of its (non-null) children.
    /// Rebalancing uses it to split a node with two red children. Deletion
    /// uses it to merge a node with its two black children.
    void _color_flip(Node* node)
    {
        assert(node->left && node->right);
        node->color = _opposite(node->color);
        node->left->color = _opposite(node->left->color);
        node->right->color = _opposite(node->right->color);
    }

    /// Before descending left: makes node->left or one of its children red.
    Node* _move_red_left(Node* node)
    {
        _color_flip(node);
        if (_is_red(node->right->left))
        {
            // Borrow from the right sibling instead of leaving a red right link.
            node->right = _rotate_right(node->right);
            node = _rotate_left(node);
            _color_flip(node);
        }
        return node;
    }

    /// Before descending right: makes node->right or one of its children red.
    Node* _move_red_right(Node* node)
    {
        _color_flip(node);
        if (_is_red(node->left->left))
        {
            // Borrow from the left sibling.
            node = _rotate_right(node);
            _color_flip(node);
        }
        return node;
    }

    /// Restores the left-leaning invariants at `node` on the way back up.
    /// Also refreshes its cached size. Returns the new subtree root.
    Node* _handle_colors(Node* node)
    {
        if (_is_red(node->right))
        {
            node = _rotate_left(node);
        }
        if (_is_red(node->left) && _is_red(node->left->left))
        {
            node = _rotate_right(node);
        }
        if (_is_red(node->left) && _is_red(node->right))
        {
            _color_flip(node);
        }
        node->resize();
        return node;
    }

    /// Null links count as black.
    static bool _is_red(const Node* node)
    {
        return node && node->color == Color::Red;
    }

    Node* _root;
};
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment