Skip to content

Instantly share code, notes, and snippets.

@f0lie
Last active October 30, 2022 22:59
Show Gist options
  • Save f0lie/fc213e1047c4ca9f829b58c41f91091f to your computer and use it in GitHub Desktop.
Save f0lie/fc213e1047c4ca9f829b58c41f91091f to your computer and use it in GitHub Desktop.
C++14 AVL Tree with Unique Pointers (uses std::make_unique, so C++14 or later is required)
#include <algorithm>
#include <cstddef>
#include <exception>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
//https://gist.github.com/f0lie/fc213e1047c4ca9f829b58c41f91091f
// AVL tree: a self-balancing binary search tree mapping unique keys to values.
//
// Every node owns its two subtrees through std::unique_ptr, so destroying the
// tree (or resetting any link) frees the whole subtree below it.
//
// Requirements on Key: copyable and comparable with operator== and operator<.
// erase() additionally formats a missing key with std::to_string, so erase()
// only instantiates for key types that std::to_string accepts.
//
// Height convention: an empty subtree has height -1 and a leaf has height 0.
template<typename Key, typename T>
class bst {
private:
    // One tree node: the key/value pair, the two owned subtrees and the cached
    // height of the subtree rooted at this node.
    struct node {
        Key key = Key{};
        T data = T{};
        std::unique_ptr<node> left = nullptr;
        std::unique_ptr<node> right = nullptr;
        int height = 0;

        node() {}
        node(Key k, T t) : key{std::move(k)}, data{std::move(t)} {}
        // The trailing node* is accepted for signature compatibility only;
        // nodes do not store a parent link.
        node(Key k, T t, std::unique_ptr<node> left, std::unique_ptr<node> right, node* /*parent*/)
            : key{std::move(k)}, data{std::move(t)},
              left{std::move(left)},
              right{std::move(right)} {}

        bool operator<(const node& rhs) const { return key < rhs.key; }
        bool operator==(const node& rhs) const { return key == rhs.key; }
    };

    std::unique_ptr<node> root;

    // Cached height of the subtree at curr; -1 for an empty subtree.
    int height(const node* curr) const { return curr ? curr->height : -1; }

    // Recomputes curr's cached height from its children's cached heights.
    void update_height(node* curr) const {
        curr->height = std::max(height(curr->left.get()), height(curr->right.get())) + 1;
    }

    // height(left) - height(right); positive means left-heavy. 0 for nullptr.
    int balance_factor(const node* curr) const {
        return curr ? height(curr->left.get()) - height(curr->right.get()) : 0;
    }

    // Returns the leftmost (smallest-key) node of the non-empty subtree at curr.
    node* findMin(node* curr) const {
        while (curr->left) {
            curr = curr->left.get();
        }
        return curr;
    }

    // Returns the node holding key, or nullptr when the key is absent.
    node* find_node(const Key& key) const {
        node* curr = root.get();
        while (curr && !(curr->key == key)) {
            curr = key < curr->key ? curr->left.get() : curr->right.get();
        }
        return curr;
    }

    // Number of nodes in the subtree at curr.
    std::size_t count_nodes(const node* curr) const {
        return curr ? 1 + count_nodes(curr->left.get()) + count_nodes(curr->right.get()) : 0;
    }

    // Rotates parent down to the right: its left child becomes the new subtree
    // root, which is returned. parent must have a left child.
    std::unique_ptr<node> rightRotate(std::unique_ptr<node> parent) {
        std::unique_ptr<node> pivot = std::move(parent->left);
        parent->left = std::move(pivot->right);
        // parent now has its final children, so fix its height before the
        // pivot's height (which depends on it) is recomputed.
        update_height(parent.get());
        pivot->right = std::move(parent);
        update_height(pivot.get());
        return pivot;
    }

    // Rotates parent down to the left: its right child becomes the new subtree
    // root, which is returned. parent must have a right child.
    std::unique_ptr<node> leftRotate(std::unique_ptr<node> parent) {
        std::unique_ptr<node> pivot = std::move(parent->right);
        parent->right = std::move(pivot->left);
        update_height(parent.get());
        pivot->left = std::move(parent);
        update_height(pivot.get());
        return pivot;
    }

    // Restores the AVL invariant at curr (|balance factor| <= 1), assuming both
    // subtrees already satisfy it and curr's cached height is up to date.
    void rebalance(std::unique_ptr<node>& curr) {
        const int factor = balance_factor(curr.get());
        if (factor > 1) {
            // Left-heavy. Only a left child that leans RIGHT needs the extra
            // inner rotation; a child with factor 0 (possible after an erase)
            // must get the single rotation, otherwise the result can stay
            // unbalanced.
            if (balance_factor(curr->left.get()) < 0) {
                curr->left = leftRotate(std::move(curr->left));
            }
            curr = rightRotate(std::move(curr));
        } else if (factor < -1) {
            // Right-heavy: mirror image of the case above.
            if (balance_factor(curr->right.get()) > 0) {
                curr->right = rightRotate(std::move(curr->right));
            }
            curr = leftRotate(std::move(curr));
        }
    }

    // Removes key from the subtree owned by curr and rebalances on the way back
    // up. Throws std::out_of_range when the key is not present.
    void _erase(const Key& key, std::unique_ptr<node>& curr) {
        if (!curr) {
            throw std::out_of_range("Erase: Key not found: " + std::to_string(key));
        }
        if (curr->key == key) {
            if (curr->left && curr->right) {
                // Two children: take over the in-order successor's key/value,
                // then erase the successor from the right subtree.
                node* successor = findMin(curr->right.get());
                curr->data = successor->data;
                curr->key = successor->key;
                _erase(curr->key, curr->right);
            } else if (curr->left) {
                // Only a left child: it replaces this node (the old node is
                // destroyed by the unique_ptr assignment).
                curr = std::move(curr->left);
            } else if (curr->right) {
                // Only a right child: same, mirrored.
                curr = std::move(curr->right);
            } else {
                // Leaf: curr refers to the owning link itself, so reset()
                // destroys the node and clears the link. release() would
                // leak the node.
                curr.reset();
            }
        } else {
            _erase(key, key < curr->key ? curr->left : curr->right);
        }
        if (curr) {
            update_height(curr.get());
            rebalance(curr);
        }
    }

    // Inserts newNode into the subtree owned by curr, or overwrites the value
    // when a node with the same key already exists; rebalances on the way up.
    void _insert(std::unique_ptr<node> newNode, std::unique_ptr<node>& curr) {
        if (!curr) {
            // curr refers to the empty link where the node belongs.
            curr = std::move(newNode);
        } else if (newNode->key == curr->key) {
            curr->data = std::move(newNode->data);
        } else {
            // Choose the subtree BEFORE moving newNode: function arguments are
            // evaluated in an unspecified order, so reading newNode->key in the
            // same argument list that moves from newNode may dereference null.
            std::unique_ptr<node>& child = newNode->key < curr->key ? curr->left : curr->right;
            _insert(std::move(newNode), child);
        }
        update_height(curr.get());
        rebalance(curr);
    }

    // Appends the subtree's (key, value) pairs in left-node-right order.
    void _inorder(std::vector<std::pair<Key, T>>& contents, const node* curr) const {
        if (!curr) return;
        _inorder(contents, curr->left.get());
        contents.emplace_back(curr->key, curr->data);
        _inorder(contents, curr->right.get());
    }

    // Appends the subtree's (key, value) pairs in node-left-right order.
    void _preorder(std::vector<std::pair<Key, T>>& contents, const node* curr) const {
        if (!curr) return;
        contents.emplace_back(curr->key, curr->data);
        _preorder(contents, curr->left.get());
        _preorder(contents, curr->right.get());
    }

    // Appends the subtree's (key, value) pairs in left-right-node order.
    void _postorder(std::vector<std::pair<Key, T>>& contents, const node* curr) const {
        if (!curr) return;
        _postorder(contents, curr->left.get());
        _postorder(contents, curr->right.get());
        contents.emplace_back(curr->key, curr->data);
    }

public:
    // Inserts (key, t); if key is already present its value is replaced by t.
    void insert(const Key& key, const T& t) { _insert(std::make_unique<node>(key, t), root); }

    // Removes key. Throws std::out_of_range if key is not in the tree.
    void erase(const Key& key) { _erase(key, root); }

    // True when the tree holds no elements.
    bool empty() const { return root == nullptr; }

    // Height of the tree: -1 when empty, 0 for a single node.
    int height() const { return height(root.get()); }

    // Returns a reference to the value stored under key.
    // Throws std::out_of_range if key is not in the tree.
    // Note: the returned reference is mutable even though at() is const.
    T& at(const Key& key) const {
        if (node* found = find_node(key)) {
            return found->data;
        }
        throw std::out_of_range("At: Key not found");
    }

    // True when key is stored in the tree.
    bool contains(const Key& key) const { return find_node(key) != nullptr; }

    // Number of elements; computed by walking the whole tree.
    std::size_t size() const { return count_nodes(root.get()); }

    // All (key, value) pairs in ascending key order.
    std::vector<std::pair<Key, T>> inorder_contents() const {
        std::vector<std::pair<Key, T>> contents;
        _inorder(contents, root.get());
        return contents;
    }

    // All (key, value) pairs in pre-order (node, left subtree, right subtree).
    std::vector<std::pair<Key, T>> preorder_contents() const {
        std::vector<std::pair<Key, T>> contents;
        _preorder(contents, root.get());
        return contents;
    }

    // All (key, value) pairs in post-order (left subtree, right subtree, node).
    std::vector<std::pair<Key, T>> postorder_contents() const {
        std::vector<std::pair<Key, T>> contents;
        _postorder(contents, root.get());
        return contents;
    }
};
int main() {
// https://gist.github.com/f0lie/fc213e1047c4ca9f829b58c41f91091f
bst<int, int> tree;
tree.insert(2,1);
tree.insert(1,1);
tree.insert(9,1);
tree.insert(3,1);
tree.insert(8,1);
tree.insert(4,1);
tree.insert(10,1);
tree.insert(5,1);
tree.insert(6,1);
tree.insert(7,1);
std::vector<std::pair<int,int>> contents = tree.inorder_contents();
for (auto key: contents) {
std::cout << key.first << " ";
}
std::cout << tree.height();
tree.erase(2);
std::cout << tree.height();
tree.erase(1);
std::cout << tree.height();
tree.erase(9);
contents = tree.inorder_contents();
for (auto key: contents) {
std::cout << key.first << " ";
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment