Skip to content

Instantly share code, notes, and snippets.

@piscisaureus
Created February 11, 2018 23:55
Show Gist options
  • Save piscisaureus/23c24af39a44eb8ba6a6eb715041706a to your computer and use it in GitHub Desktop.
Save piscisaureus/23c24af39a44eb8ba6a6eb715041706a to your computer and use it in GitHub Desktop.
Red-black tree
#include <assert.h>
#include <malloc.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
/* Signed, pointer-sized integer used for the test loops' counters below.
 * ssize_t is POSIX rather than ISO C (MSVC, for example, does not provide
 * it), so it is defined here.  NOTE(review): on POSIX targets whose headers
 * already declare ssize_t, this relies on C11 permitting a repeated typedef
 * of a compatible type -- confirm intptr_t and ssize_t match on every
 * supported target. */
typedef intptr_t ssize_t;
/* Forward typedef so the node's own fields can use the short name. */
typedef struct tree_node tree_node_t;

/* One tree node.  Nodes are allocated by the caller; the tree functions
 * only link and unlink them and never allocate or free memory. */
struct tree_node {
  tree_node_t* left;   /* child holding smaller keys, or NULL */
  tree_node_t* right;  /* child holding larger keys, or NULL */
  tree_node_t* parent; /* NULL for the root node */
  bool red;            /* true = red, false = black */
  uintptr_t key;       /* ordering key; unique within one tree */
};

/* A tree is a pointer to its root node; NULL means the tree is empty. */
typedef struct tree {
  tree_node_t* root;
} tree_t;
/* Left rotation around `node`: its right child (`pivot`) moves up into
 * `node`'s position and `node` becomes pivot's left child.  Pivot's former
 * left subtree is re-hung as `node`'s right subtree.  The tree's root
 * pointer is updated when `node` was the root.  `node->right` must be
 * non-NULL. */
static void _tree_rotate_left(tree_t* tree, tree_node_t* node) {
  tree_node_t* pivot = node->right;
  tree_node_t* up = node->parent;
  tree_node_t* inner = pivot->left;

  /* Pivot's inner subtree changes sides and now hangs under node. */
  node->right = inner;
  if (inner != NULL)
    inner->parent = node;

  /* Pivot takes node's place below `up` (or becomes the root). */
  pivot->parent = up;
  if (up == NULL)
    tree->root = pivot;
  else if (up->left == node)
    up->left = pivot;
  else
    up->right = pivot;

  /* Finally node drops below pivot. */
  pivot->left = node;
  node->parent = pivot;
}
/* Right rotation around `node`: its left child (`pivot`) moves up into
 * `node`'s position and `node` becomes pivot's right child.  Pivot's former
 * right subtree is re-hung as `node`'s left subtree.  The tree's root
 * pointer is updated when `node` was the root.  `node->left` must be
 * non-NULL. */
static void _tree_rotate_right(tree_t* tree, tree_node_t* node) {
  tree_node_t* pivot = node->left;
  tree_node_t* up = node->parent;
  tree_node_t* inner = pivot->right;

  /* Pivot's inner subtree changes sides and now hangs under node. */
  node->left = inner;
  if (inner != NULL)
    inner->parent = node;

  /* Pivot takes node's place below `up` (or becomes the root). */
  pivot->parent = up;
  if (up == NULL)
    tree->root = pivot;
  else if (up->left == node)
    up->left = pivot;
  else
    up->right = pivot;

  /* Finally node drops below pivot. */
  pivot->right = node;
  node->parent = pivot;
}
/* Initializes `tree` as an empty tree (no root node). */
void tree_init(tree_t* tree) {
  tree->root = NULL;
}
/* Resets every field of `node` to zero: no links, black, key 0. */
void tree_node_init(tree_node_t* node) {
  *node = (tree_node_t){
      .left = NULL, .right = NULL, .parent = NULL, .red = false, .key = 0};
}
/* Inserts the caller-owned `node` into `tree` under `key`, then rebalances.
 *
 * Returns 0 on success.  Returns -1 if `key` is already present; in that
 * case neither the tree nor `node` is modified.  On success, `node`'s key,
 * links and color are all overwritten, so its prior contents do not matter. */
int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) {
  tree_node_t** link = &tree->root;
  tree_node_t* parent = NULL;

  /* Walk down to the empty child slot where `key` belongs. */
  while (*link != NULL) {
    parent = *link;
    if (key < parent->key)
      link = &parent->left;
    else if (key > parent->key)
      link = &parent->right;
    else
      return -1; /* duplicate key */
  }

  /* Hang the new node in that slot as a red leaf. */
  *link = node;
  node->key = key;
  node->left = NULL;
  node->right = NULL;
  node->parent = parent;
  node->red = true;

  /* Fix red-red violations bottom-up.  A red parent is never the root
   * (the root is kept black), so the grandparent always exists. */
  for (; parent != NULL && parent->red; parent = node->parent) {
    tree_node_t* grandparent = parent->parent;

    if (parent == grandparent->left) {
      tree_node_t* uncle = grandparent->right;
      if (uncle != NULL && uncle->red) {
        /* Red uncle: recolor and continue from the grandparent. */
        parent->red = false;
        uncle->red = false;
        grandparent->red = true;
        node = grandparent;
        continue;
      }
      if (node == parent->right) {
        /* Inner child: rotate it to the outside first. */
        _tree_rotate_left(tree, parent);
        node = parent;
        parent = node->parent;
      }
      parent->red = false;
      grandparent->red = true;
      _tree_rotate_right(tree, grandparent);
    } else {
      tree_node_t* uncle = grandparent->left;
      if (uncle != NULL && uncle->red) {
        /* Red uncle: recolor and continue from the grandparent. */
        parent->red = false;
        uncle->red = false;
        grandparent->red = true;
        node = grandparent;
        continue;
      }
      if (node == parent->left) {
        /* Inner child: rotate it to the outside first. */
        _tree_rotate_right(tree, parent);
        node = parent;
        parent = node->parent;
      }
      parent->red = false;
      grandparent->red = true;
      _tree_rotate_left(tree, grandparent);
    }
  }

  tree->root->red = false;
  return 0;
}
/* Unlinks `node` from `tree` and rebalances so the red-black invariants hold
 * again.  `node` is expected to be linked into `tree` (e.g. a result of
 * tree_find).  The tree never frees nodes: the caller still owns `node`'s
 * memory afterwards, and `node`'s own link fields are not cleared.
 *
 * Phase 1 splices the node out, replacing it with its in-order successor
 * when it has two children.  Phase 2 runs only when a black position was
 * removed; it repairs the resulting black-height deficit, starting at the
 * child that took the removed position. */
void tree_del(tree_t* tree, tree_node_t* node) {
  tree_node_t* parent = node->parent;
  tree_node_t* left = node->left;
  tree_node_t* right = node->right;
  tree_node_t* next;    /* node that takes `node`'s place under `parent` */
  tree_node_t* sibling;
  bool red;             /* color of the position that physically disappears */
  /* Choose the replacement: the only child (possibly NULL), or the in-order
   * successor (leftmost node of the right subtree) when both children exist. */
  if (!left) {
    next = right;
  } else if (!right) {
    next = left;
  } else {
    next = right;
    while (next->left)
      next = next->left;
  }
  /* Point `node`'s parent (or the root pointer) at the replacement. */
  if (parent) {
    if (parent->left == node)
      parent->left = next;
    else
      parent->right = next;
  } else {
    tree->root = next;
  }
  if (left && right) {
    /* The successor inherits `node`'s color, so the color that is really
     * removed from the tree is the successor's original one. */
    red = next->red;
    next->red = node->red;
    next->left = left;
    left->parent = next;
    if (next != right) {
      /* Successor sits deeper in the right subtree.  It has no left child,
       * so its right child (possibly NULL) takes its slot under its old
       * parent, and the successor adopts `node`'s right subtree. */
      parent = next->parent;
      next->parent = node->parent;
      node = next->right;
      parent->left = node;
      next->right = right;
      right->parent = next;
    } else {
      /* Successor is `node`'s right child; it keeps its own right subtree. */
      next->parent = parent;
      parent = next;
      node = next->right;
    }
  } else {
    red = node->red;
    node = next;
  }
  /* From here on, `node` is the (possibly NULL) child now occupying the
   * removed position and `parent` is its parent. */
  if (node)
    node->parent = parent;
  /* Removing a red position does not change any black height. */
  if (red)
    return;
  /* A red replacement absorbs the missing black by turning black. */
  if (node && node->red) {
    node->red = false;
    return;
  }
  /* `node`'s side is one black short.  `node` may be NULL here; the loop
   * dereferences `sibling` unconditionally, relying on the black-height
   * invariant to guarantee the sibling of a removed black position exists. */
  do {
    if (node == tree->root)
      break;
    if (node == parent->left) {
      sibling = parent->right;
      /* Red sibling: rotate so that `node` gets a black sibling. */
      if (sibling->red) {
        sibling->red = false;
        parent->red = true;
        _tree_rotate_left(tree, parent);
        sibling = parent->right;
      }
      /* Sibling has a red child: one or two rotations settle the deficit. */
      if ((sibling->left && sibling->left->red) ||
          (sibling->right && sibling->right->red)) {
        /* Only the inner (left) child is red: rotate it to the outside. */
        if (!sibling->right || !sibling->right->red) {
          sibling->left->red = false;
          sibling->red = true;
          _tree_rotate_right(tree, sibling);
          sibling = parent->right;
        }
        sibling->red = parent->red;
        parent->red = sibling->right->red = false;
        _tree_rotate_left(tree, parent);
        node = tree->root; /* done; the final recolor blackens the root */
        break;
      }
    } else {
      /* Mirror image of the branch above. */
      sibling = parent->left;
      if (sibling->red) {
        sibling->red = false;
        parent->red = true;
        _tree_rotate_right(tree, parent);
        sibling = parent->left;
      }
      if ((sibling->left && sibling->left->red) ||
          (sibling->right && sibling->right->red)) {
        if (!sibling->left || !sibling->left->red) {
          sibling->right->red = false;
          sibling->red = true;
          _tree_rotate_left(tree, sibling);
          sibling = parent->left;
        }
        sibling->red = parent->red;
        parent->red = sibling->left->red = false;
        _tree_rotate_right(tree, parent);
        node = tree->root; /* done; the final recolor blackens the root */
        break;
      }
    }
    /* Sibling and both its children are black: recolor the sibling red and
     * move the deficit one level up, to `parent`. */
    sibling->red = true;
    node = parent;
    parent = parent->parent;
  } while (!node->red);
  /* `node` is either the root or a red node that absorbs the missing black. */
  if (node)
    node->red = false;
}
/* Looks up `key` in `tree`.  Returns the node holding it, or NULL if no node
 * has that key. */
tree_node_t* tree_find(const tree_t* tree, uintptr_t key) {
  tree_node_t* cur = tree->root;
  /* Stop on a match or when we fall off the bottom (cur == NULL). */
  while (cur != NULL && cur->key != key)
    cur = (key < cur->key) ? cur->left : cur->right;
  return cur;
}
/* Returns the root node of `tree`, or NULL when the tree is empty. */
tree_node_t* tree_root(const tree_t* tree) {
  tree_node_t* const top = tree->root;
  return top;
}
static size_t check_subtree(const tree_node_t* node) {
size_t black_height_left;
size_t black_height_right;
if (!node)
return 0;
black_height_left = check_subtree(node->left);
black_height_right = check_subtree(node->right);
assert(black_height_left == black_height_right);
if (node->red) {
assert(!node->left || !node->left->red);
assert(!node->right || !node->right->red);
return black_height_left;
} else {
return black_height_left + 1;
}
}
static void check_tree(const tree_t* tree) {
check_subtree(tree_root(tree));
}
/* Test helper: returns the number of nodes in the subtree rooted at `node`
 * (0 for NULL).  It walks the right spine iteratively and recurses only
 * into left subtrees. */
static size_t count_subtree(const tree_node_t* node) {
  size_t total = 0;
  while (node != NULL) {
    total += 1 + count_subtree(node->left);
    node = node->right;
  }
  return total;
}
static size_t count_tree(const tree_t* tree) {
return count_subtree(tree_root(tree));
}
/* Test helper: asserts that `tree` holds exactly `expected_count` nodes.
 * count_tree has no side effects, so it is safe for the whole check to
 * compile away under NDEBUG. */
static void check_tree_count(const tree_t* tree, size_t expected_count) {
  assert(count_tree(tree) == expected_count);
}
/* TESTS */
/* Each test pass works with the distinct keys 0 .. NODE_COUNT-1. */
enum { NODE_COUNT = 1000 };
/* The shuffling driver picks key indices with rand(), which only yields
 * values up to RAND_MAX. */
_Static_assert(NODE_COUNT <= RAND_MAX, "NODE_COUNT too high");
/* Per-key test operation invoked by the key-order drivers below. */
typedef void (*test_op_t)(tree_t* target, uintptr_t key);
/* Test driver: applies `op` to the keys 0, 1, ..., NODE_COUNT-1 in
 * ascending order. */
static void increasing(tree_t* tree, test_op_t op) {
  for (uintptr_t key = 0; key < NODE_COUNT; ++key)
    op(tree, key);
}
/* Test driver: applies `op` to the keys NODE_COUNT-1, ..., 1, 0 in
 * descending order. */
static void decreasing(tree_t* tree, test_op_t op) {
  uintptr_t key = NODE_COUNT;
  while (key > 0) {
    --key;
    op(tree, key);
  }
}
/* Test driver: applies `op` to every key in [0, NODE_COUNT) exactly once,
 * in pseudo-random order.  At each step it draws uniformly from the keys not
 * yet emitted (an in-place Fisher-Yates draw).  rand() is never seeded in
 * this file, so the order is reproducible from run to run with the same C
 * library.
 *
 * Deliberately not named `random`: POSIX <stdlib.h> declares
 * `long random(void)`, and glibc exposes that declaration in its default
 * (non-strict) mode, where a static function of the same name fails to
 * compile. */
static void random_order(tree_t* tree, test_op_t op) {
  uintptr_t keys[NODE_COUNT];
  size_t remaining;
  for (size_t i = 0; i < NODE_COUNT; i++)
    keys[i] = i;
  for (remaining = NODE_COUNT; remaining > 0; remaining--) {
    /* keys[0 .. remaining-1] holds the keys not yet emitted; any of them,
     * including the last, may be chosen. */
    size_t index = (size_t)rand() % remaining;
    uintptr_t key = keys[index];
    /* Fill the hole with the last pending key, shrinking the pool by one. */
    keys[index] = keys[remaining - 1];
    op(tree, key);
  }
}

/* Test op: allocates a fresh node and inserts `key`.  Asserts that insertion
 * succeeds, the tree stays valid, and the tree grows by exactly one node.
 * The node is later freed by find_del. */
static void add(tree_t* tree, uintptr_t key) {
  tree_node_t* node;
  size_t before_count;
  int r;
  before_count = count_tree(tree);
  node = malloc(sizeof *node);
  assert(node != NULL);
  tree_node_init(node);
  r = tree_add(tree, node, key);
  assert(r == 0);
  (void)r; /* only read by assert() */
  assert(node->key == key);
  check_tree(tree);
  check_tree_count(tree, before_count + 1);
}

/* Test op: inserting an already-present `key` must fail with -1 and leave
 * the tree unchanged.  A stack node is safe here because tree_add does not
 * link the node when the key is a duplicate. */
static void add_error(tree_t* tree, uintptr_t key) {
  tree_node_t node;
  size_t before_count;
  int r;
  before_count = count_tree(tree);
  tree_node_init(&node);
  r = tree_add(tree, &node, key);
  assert(r == -1);
  (void)r; /* only read by assert() */
  check_tree(tree);
  check_tree_count(tree, before_count);
}

/* Test op: looks up `key`, which must be present, removes its node, frees
 * it, and asserts that the tree stays valid and shrinks by one node. */
static void find_del(tree_t* tree, uintptr_t key) {
  tree_node_t* node;
  size_t before_count;
  before_count = count_tree(tree);
  node = tree_find(tree, key);
  assert(node != NULL);
  assert(node->key == key);
  tree_del(tree, node);
  free(node);
  check_tree(tree);
  check_tree_count(tree, before_count - 1);
}

/* Test op: asserts that `key` is absent and that the lookup left the tree
 * unchanged. */
static void find_error(tree_t* tree, uintptr_t key) {
  tree_node_t* node;
  size_t before_count;
  before_count = count_tree(tree);
  node = tree_find(tree, key);
  assert(node == NULL);
  (void)node; /* only read by assert() */
  check_tree(tree);
  check_tree_count(tree, before_count);
}

/* Runs insert / duplicate-insert / delete / missing-lookup passes over all
 * keys in increasing, decreasing and shuffled orders.  Each pass ends with
 * an empty tree, so every allocated node is freed. */
int main(void) {
  tree_t tree;
  tree_init(&tree);

  increasing(&tree, add);
  check_tree_count(&tree, NODE_COUNT);
  increasing(&tree, add_error);
  increasing(&tree, find_del);
  check_tree_count(&tree, 0);
  increasing(&tree, find_error);

  decreasing(&tree, add);
  check_tree_count(&tree, NODE_COUNT);
  decreasing(&tree, add_error);
  decreasing(&tree, find_del);
  check_tree_count(&tree, 0);
  decreasing(&tree, find_error);

  random_order(&tree, add);
  check_tree_count(&tree, NODE_COUNT);
  random_order(&tree, add_error);
  random_order(&tree, find_del);
  check_tree_count(&tree, 0);
  random_order(&tree, find_error);

  /* Mixed orders: shuffled inserts, ordered checks and deletes. */
  random_order(&tree, add);
  check_tree_count(&tree, NODE_COUNT);
  increasing(&tree, add_error);
  decreasing(&tree, find_del);
  check_tree_count(&tree, 0);
  random_order(&tree, find_error);

  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment