Created
November 5, 2015 00:56
-
-
Save goldsborough/529689a484b58b913c94 to your computer and use it in GitHub Desktop.
Red-black Tree <3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Ordered key/value map implemented as a left-leaning red-black tree.
///
/// Keys are ordered with `operator<` only; two keys are treated as equal when
/// neither compares less than the other. Every node caches the number of
/// nodes in its subtree, which is what makes `size()` and `rank()` cheap.
///
/// Needs the standard headers <cassert>, <cstddef>, <initializer_list>,
/// <stdexcept> and <utility> to be included before this definition.
///
/// @tparam Key   Copy-constructible key type providing `operator<`.
/// @tparam Value Copy-constructible, copy-assignable mapped type
///               (default-constructible when `operator[]` is used).
template<typename Key, typename Value>
class RBTree
{
public:

    using size_t = std::size_t;

    /// Creates an empty tree.
    RBTree() noexcept
    : _root(nullptr)
    { }

    /// Creates a tree from a list of key/value pairs.
    /// A key that occurs more than once keeps the value of its last pair.
    RBTree(const std::initializer_list<std::pair<Key, Value>>& list)
    : _root(nullptr)
    {
        try
        {
            for (const auto& item : list)
            {
                insert(item.first, item.second);
            }
        }

        catch (...)
        {
            // The destructor does not run for a half-constructed object,
            // so the nodes inserted so far have to be released here.
            _clear(_root);
            throw;
        }
    }

    /// Deep-copies `other`, including node colors and subtree sizes.
    RBTree(const RBTree& other)
    : _root(_copy(other._root))
    { }

    /// Takes over the nodes of `other`, leaving it empty.
    RBTree(RBTree&& other) noexcept
    : RBTree()
    {
        swap(other);
    }

    /// Copy-and-swap assignment; serves both copy and move assignment.
    RBTree& operator=(RBTree other)
    {
        swap(other);

        return *this;
    }

    /// Exchanges the contents of two trees.
    void swap(RBTree& other) noexcept
    {
        // Enable ADL
        using std::swap;

        swap(_root, other._root);
    }

    friend void swap(RBTree& first, RBTree& second) noexcept
    {
        first.swap(second);
    }

    ~RBTree()
    {
        _clear(_root);
    }

    /// Inserts `key` with `value`, or overwrites the value of an existing key.
    void insert(const Key& key, const Value& value)
    {
        _root = _insert(_root, key, value);

        // A color flip at the top may have turned the root red.
        _root->color = Color::Black;
    }

    /// Removes `key` from the tree.
    /// @throws std::invalid_argument if the key is not present; the tree is
    ///         left untouched in that case.
    void erase(const Key& key)
    {
        if (! _find(_root, key))
        {
            throw std::invalid_argument("No such key!");
        }

        // The deletion keeps the invariant "the current node or its left
        // child is red" while descending; establish it at the root.
        if (! _is_red(_root->left) && ! _is_red(_root->right))
        {
            _root->color = Color::Red;
        }

        _root = _erase(_root, key);

        if (_root) _root->color = Color::Black;
    }

    /// Removes all entries. The tree stays usable afterwards.
    void clear()
    {
        _clear(_root);

        // Without this reset the destructor would free the nodes again.
        _root = nullptr;
    }

    /// Returns the value stored for `key`.
    /// @throws std::invalid_argument if the key is not present.
    Value& get(const Key& key)
    {
        auto node = _find(_root, key);

        if (! node)
        {
            throw std::invalid_argument("No such key!");
        }

        return node->value;
    }

    /// Returns the value stored for `key`.
    /// @throws std::invalid_argument if the key is not present.
    const Value& get(const Key& key) const
    {
        auto node = _find(_root, key);

        if (! node)
        {
            throw std::invalid_argument("No such key!");
        }

        return node->value;
    }

    /// Returns true if `key` is present.
    bool contains(const Key& key) const
    {
        return _find(_root, key) != nullptr;
    }

    /// Returns the value stored for `key`, inserting a default-constructed
    /// value first if the key is not present.
    Value& operator[](const Key& key)
    {
        auto node = _find(_root, key);

        if (! node)
        {
            insert(key, Value());

            node = _find(_root, key);
        }

        return node->value;
    }

    /// Returns the largest stored key that is not greater than `key`
    /// (`key` itself if it is present).
    /// @throws std::runtime_error if no such key exists.
    const Key& floor(const Key& key) const
    {
        auto node = _floor(_root, key);

        if (! node)
        {
            throw std::runtime_error("No floor for given key!");
        }

        return node->key;
    }

    /// Returns the smallest stored key that is not less than `key`
    /// (`key` itself if it is present).
    /// @throws std::runtime_error if no such key exists.
    const Key& ceiling(const Key& key) const
    {
        auto node = _ceiling(_root, key);

        if (! node)
        {
            throw std::runtime_error("No ceiling for given key!");
        }

        return node->key;
    }

    /// Returns the number of stored keys that are less than or equal to `key`.
    size_t rank(const Key& key) const
    {
        return _rank(_root, key);
    }

    /// Returns the number of stored keys in the closed range [first, second].
    /// Neither bound has to be present; an inverted range yields 0.
    size_t size(const Key& first, const Key& second) const
    {
        if (second < first) return 0;

        // rank() counts keys <= its argument, so the difference misses
        // `first` itself; add it back only if it is actually stored.
        size_t count = rank(second) - rank(first);

        if (contains(first)) ++count;

        return count;
    }

    /// Returns the total number of entries.
    size_t size() const
    {
        return _root ? _root->size : 0;
    }

    /// Returns true if the tree holds no entries.
    bool is_empty() const
    {
        return size() == 0;
    }

private:

    enum class Color { Red, Black };

    struct Node
    {
        Node(const Key& key_,
             const Value& value_ = Value(),
             Node* left_ = nullptr,
             Node* right_ = nullptr,
             Color color_ = Color::Red)
        : key(key_)
        , value(value_)
        , left(left_)
        , right(right_)
        , size(1)
        , color(color_)
        {
            resize();
        }

        /// Recomputes the cached subtree size from the children's sizes.
        void resize()
        {
            size = 1;

            if (left) size += left->size;

            if (right) size += right->size;
        }

        Key key;

        Value value;

        Node* left;

        Node* right;

        size_t size;

        Color color;
    };

    /// Size of the subtree rooted at `node`; 0 for an empty subtree.
    static size_t _size(const Node* node)
    {
        return node ? node->size : 0;
    }

    /// Recursive insertion; returns the (possibly new) root of the subtree.
    Node* _insert(Node* node, const Key& key, const Value& value)
    {
        if (! node) return new Node(key, value);

        if (key < node->key)
        {
            node->left = _insert(node->left, key, value);
        }

        else if (node->key < key)
        {
            node->right = _insert(node->right, key, value);
        }

        else node->value = value;

        return _handle_colors(node);
    }

    /// Recursive left-leaning red-black deletion.
    /// Precondition: `key` is present in the subtree rooted at `node`
    /// (erase() checks this), which is what guarantees that the child
    /// links dereferenced below are not null.
    Node* _erase(Node* node, const Key& key)
    {
        if (key < node->key)
        {
            // Never descend into a 2-node: borrow a red link first.
            if (! _is_red(node->left) && ! _is_red(node->left->left))
            {
                node = _move_red_left(node);
            }

            node->left = _erase(node->left, key);
        }

        else
        {
            if (_is_red(node->left)) node = _rotate_right(node);

            // In this branch !(key < node->key) always holds, so
            // "node->key is not less than key" means the keys are equal.
            if (! (node->key < key) && ! node->right)
            {
                // Match at the bottom of the tree: unlink it.
                Node* left = node->left;

                delete node;

                return left;
            }

            if (! _is_red(node->right) && ! _is_red(node->right->left))
            {
                node = _move_red_right(node);
            }

            if (! (node->key < key))
            {
                // Replace the node by its in-order successor. The successor
                // node is relinked rather than copied, so references to the
                // values of other entries stay valid.
                Node* successor = nullptr;

                Node* remaining = _extract_min(node->right, successor);

                successor->left = node->left;
                successor->right = remaining;
                successor->color = node->color;

                delete node;

                node = successor;
            }

            else
            {
                node->right = _erase(node->right, key);
            }
        }

        return _handle_colors(node);
    }

    /// Detaches the minimum node of the subtree rooted at `node`, stores it
    /// in `min` and returns the root of what is left of the subtree.
    Node* _extract_min(Node* node, Node*& min)
    {
        if (! node->left)
        {
            min = node;

            // The minimum of a balanced left-leaning tree normally has no
            // right child; handing back the right link is correct either way.
            return node->right;
        }

        if (! _is_red(node->left) && ! _is_red(node->left->left))
        {
            node = _move_red_left(node);
        }

        node->left = _extract_min(node->left, min);

        return _handle_colors(node);
    }

    /// Makes the left child (or one of its children) red, assuming `node`
    /// is red and both `node->left` and `node->left->left` are black.
    Node* _move_red_left(Node* node)
    {
        _flip_colors(node);

        if (_is_red(node->right->left))
        {
            node->right = _rotate_right(node->right);

            node = _rotate_left(node);

            _flip_colors(node);
        }

        return node;
    }

    /// Makes the right child (or one of its children) red, assuming `node`
    /// is red and both `node->right` and `node->right->left` are black.
    Node* _move_red_right(Node* node)
    {
        _flip_colors(node);

        if (_is_red(node->left->left))
        {
            node = _rotate_right(node);

            _flip_colors(node);
        }

        return node;
    }

    /// Returns the node holding `key`, or nullptr if there is none.
    Node* _find(Node* node, const Key& key) const
    {
        while (node)
        {
            if (key < node->key) node = node->left;

            else if (node->key < key) node = node->right;

            else return node;
        }

        return nullptr;
    }

    /// Deletes every node of the subtree rooted at `node`.
    void _clear(Node* node)
    {
        if (! node) return;

        _clear(node->left);

        _clear(node->right);

        delete node;
    }

    /// Returns a deep copy of the subtree rooted at `other`.
    Node* _copy(const Node* other)
    {
        if (! other) return nullptr;

        Node* left = _copy(other->left);

        Node* right = nullptr;

        try
        {
            right = _copy(other->right);

            // The constructor computes the size from the copied children.
            return new Node(other->key,
                            other->value,
                            left,
                            right,
                            other->color);
        }

        catch (...)
        {
            _clear(left);

            _clear(right);

            throw;
        }
    }

    /// Node with the smallest key >= `key`, or nullptr.
    Node* _ceiling(Node* node, const Key& key) const
    {
        Node* best = nullptr;

        while (node)
        {
            if (node->key < key) node = node->right;

            else
            {
                // node->key >= key: a candidate; look for a smaller one.
                best = node;

                node = node->left;
            }
        }

        return best;
    }

    /// Node with the largest key <= `key`, or nullptr.
    Node* _floor(Node* node, const Key& key) const
    {
        Node* best = nullptr;

        while (node)
        {
            if (key < node->key) node = node->left;

            else
            {
                // node->key <= key: a candidate; look for a larger one.
                best = node;

                node = node->right;
            }
        }

        return best;
    }

    /// Number of keys <= `key` in the subtree rooted at `node`.
    size_t _rank(Node* node, const Key& key) const
    {
        size_t rank = 0;

        while (node)
        {
            if (key < node->key) node = node->left;

            else
            {
                // This node and its whole left subtree are <= key.
                rank += 1 + _size(node->left);

                node = node->right;
            }
        }

        return rank;
    }

    /// Turns a red right link into a red left link.
    Node* _rotate_left(Node* node)
    {
        assert(_is_red(node->right));

        auto right = node->right;

        node->right = right->left;
        right->left = node;

        right->color = node->color;
        node->color = Color::Red;

        // Child first: the new subtree root's size depends on it.
        node->resize();
        right->resize();

        return right;
    }

    /// Turns a red left link into a red right link.
    Node* _rotate_right(Node* node)
    {
        assert(_is_red(node->left));

        auto left = node->left;

        node->left = left->right;
        left->right = node;

        left->color = node->color;
        node->color = Color::Red;

        // Child first: the new subtree root's size depends on it.
        node->resize();
        left->resize();

        return left;
    }

    /// Toggles the color of `node` and of both of its children.
    /// During insertion this splits a temporary 4-node (black parent, two
    /// red children); during deletion it also performs the reverse merge.
    void _flip_colors(Node* node)
    {
        assert(node->left && node->right);

        node->color = _opposite(node->color);
        node->left->color = _opposite(node->left->color);
        node->right->color = _opposite(node->right->color);
    }

    static Color _opposite(Color color)
    {
        return color == Color::Red ? Color::Black : Color::Red;
    }

    /// Restores the left-leaning red-black invariants at `node` on the way
    /// back up from a modification and refreshes its cached size.
    Node* _handle_colors(Node* node)
    {
        if (_is_red(node->right) && ! _is_red(node->left))
        {
            node = _rotate_left(node);
        }

        if (_is_red(node->left) && _is_red(node->left->left))
        {
            node = _rotate_right(node);
        }

        if (_is_red(node->left) && _is_red(node->right))
        {
            _flip_colors(node);
        }

        node->resize();

        return node;
    }

    /// Null links count as black.
    bool _is_red(Node* node) const
    {
        return node && node->color == Color::Red;
    }

    Node* _root;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment