Last active
November 28, 2017 15:32
-
-
Save martinkunev/720386 to your computer and use it in GitHub Desktop.
AVL tree (C implementation)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Conquest of Levidon | |
* Copyright (C) 2017 Martin Kunev <[email protected]> | |
* | |
* This file is part of Conquest of Levidon. | |
* | |
* Conquest of Levidon is free software: you can redistribute it and/or modify | |
* it under the terms of the GNU General Public License as published by | |
* the Free Software Foundation version 3 of the License. | |
* | |
* Conquest of Levidon is distributed in the hope that it will be useful, | |
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
* GNU General Public License for more details. | |
* | |
* You should have received a copy of the GNU General Public License | |
* along with Conquest of Levidon. If not, see <http://www.gnu.org/licenses/>. | |
*/ | |
#include <stddef.h>
#include <string.h>

// Type of the value stored in each node. Define avl_type before this point to store a different type.
#if !defined(avl_type)
# define avl_type int
#endif

// Default key ordering: lexicographic byte order where a key that is a prefix of another comes first.
// Returns positive number if key a is before key b, 0 if they are equal, negative number if a is after b.
// At most min(a_size, b_size) bytes are read from each buffer, so keys of different sizes are safe to compare.
static inline int avl_compare_keys(const unsigned char *a, size_t a_size, const unsigned char *b, size_t b_size)
{
	size_t common = (a_size < b_size) ? a_size : b_size;
	// memcmp with a zero length is skipped so that null key pointers with size 0 are never passed to it.
	int diff = common ? memcmp(b, a, common) : 0;
	if (diff)
		return diff;
	// Equal common prefix: the shorter key comes first.
	return (b_size > a_size) - (b_size < a_size);
}

// Compares the search key a (a_size bytes) with the key stored in node b (b->key_data, b->key_size bytes).
// Returns positive number if a is before b, 0 if a == b, negative number if a is after b.
// Define avl_compare before this point to use a custom ordering.
#if !defined(avl_compare)
# define avl_compare(a, a_size, b) avl_compare_keys((a), (a_size), (b)->key_data, (b)->key_size)
#endif

// Callback for updating aggregated values for a node. When called, both children of the passed node are in a consistent state.
// Define avl_update before this point to maintain per-node aggregates; the default does nothing.
#if !defined(avl_update)
# define avl_update(node) (void)0
#endif

// TODO make the factor calculation code readable
// TODO implement union, intersection, difference
// TODO test avl_update() with range searches
// Handle for an AVL tree whose keys are byte strings. Must be zero-initialized before use: struct avl avl = {0};
struct avl
{
	size_t count; // number of nodes in the tree; incremented by avl_insert, decremented by avl_remove
	// A tree node. The key bytes live inline after the fixed fields, in the same allocation.
	// key_size and key_data are const for users of the tree; the implementation writes them through
	// struct avl_node_mutable, which must keep exactly the same member order and types.
	struct avl_node
	{
		struct avl_node *next[2]; // children: next[0] holds keys avl_compare orders before this one, next[1] keys ordered after
		const size_t key_size; // length of key_data in bytes
		avl_type value;
		signed char factor; // balance factor: height of next[0] minus height of next[1], kept within [-1, 1]
		const unsigned char key_data[]; // key bytes (flexible array member)
	} *root; // NULL for an empty tree
};
// A newly created struct avl must be initialized with zeroes: struct avl avl = {0};
// Keys are byte strings (key_data, key_size bytes); their order is defined by avl_compare.

// Returns a pointer to the value stored for the key, or NULL if the key is not in the tree.
// NOTE(review): the pointer refers to memory inside a tree node; treat it as invalidated by avl_remove/avl_term.
avl_type *avl_get(struct avl *avl, const unsigned char *restrict key_data, size_t key_size);
// Inserts the key with the given value and returns a pointer to the stored value.
// If the key is already present, the existing value is left unchanged and a pointer to it is returned.
// Returns NULL if allocating the new node fails. avl->count is incremented only when a node is added.
avl_type *avl_insert(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, avl_type value);
// Removes the key from the tree if it is present. When a node is removed and value_old is not NULL,
// the removed value is stored in *value_old; otherwise *value_old is left untouched.
void avl_remove(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, avl_type *value_old);
// Calls callback(node, argument) on nodes in key order, starting with the first node whose key is not
// ordered before key_data (from the first node of the tree when key_data is NULL).
// Iteration stops as soon as callback returns nonzero.
void avl_iterate(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, int (*callback)(struct avl_node *, void *), void *argument);
// Frees every node of the tree.
void avl_term(struct avl *avl);
/* avl.c */ | |
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
// Same members, in the same order and with the same types, as struct avl_node but without const qualifiers.
// The implementation accesses a node through this type when it has to write key_size or key_data
// (e.g. right after malloc), so the two definitions must be kept in sync.
struct avl_node_mutable
{
	struct avl_node *next[2];
	size_t key_size;
	avl_type value;
	signed char factor;
	unsigned char key_data[];
};
static avl_type *avl_get_internal(struct avl_node *t, const unsigned char *restrict key_data, size_t key_size) | |
{ | |
while (t) | |
{ | |
int diff = avl_compare(key_data, key_size, t); | |
if (!diff) return &t->value; | |
t = t->next[diff < 0]; | |
} | |
return 0; | |
} | |
avl_type *avl_get(struct avl *avl, const unsigned char *restrict key_data, size_t key_size) | |
{ | |
return avl_get_internal(avl->root, key_data, key_size); | |
} | |
// Restores the AVL invariant at *branch, whose node has a balance factor of +2 or -2
// (callers invoke it once the factor leaves [-1, 1]). Uses a single or a double rotation.
// Balance factors are height(next[0]) - height(next[1]).
// *branch is updated to point to the new root of the subtree.
// Returns -1 if the rebalanced subtree is one level shorter than the unbalanced one, 0 if its height is unchanged.
static signed char avl_balance(struct avl_node *restrict *restrict branch)
{
	struct avl_node *node = *branch;
	// index selects the taller side: 0 when node is left-heavy (factor > 0), 1 when it is right-heavy.
	unsigned char index = (node->factor < 0);
	struct avl_node *child = node->next[index];
	// +1 for the left-heavy case, -1 for the right-heavy one; lets one code path handle both mirrored cases.
	signed char factor_sign = (int [2]){1, -1}[index];
	assert(node->factor);
	// Rotate the longer subtree if necessary.
	// When child leans away from the taller side (zig-zag), first rotate child so that grandchild takes its
	// place; the single rotation below then completes the double rotation.
	if (factor_sign * child->factor < 0)
	{
		struct avl_node *grandchild = child->next[index ^ 1];
		node->next[index] = grandchild;
		child->next[index ^ 1] = grandchild->next[index];
		grandchild->next[index] = child;
		// Set the balance factor
		// child ends up with factor_sign if grandchild leaned toward the index ^ 1 side, 0 otherwise.
		child->factor = factor_sign * (factor_sign * grandchild->factor < 0); // TODO fix this code
		// grandchild temporarily gets factor_sign or 2 * factor_sign (outside [-1, 1]);
		// the single rotation below reads this value as child->factor.
		grandchild->factor = factor_sign * (factor_sign * grandchild->factor > 0) + factor_sign; // TODO fix this code
		avl_update(child);
		child = grandchild;
	}
	// Single rotation: child becomes the subtree root and node becomes its index ^ 1 child.
	*branch = child;
	node->next[index] = child->next[index ^ 1];
	child->next[index ^ 1] = node;
	// Set the balance factor
	// Both expressions use child's factor from before the rotation.
	node->factor = factor_sign - child->factor; // TODO fix this code
	// child becomes balanced unless it was balanced before, in which case it now leans toward node (-factor_sign).
	child->factor = -!child->factor * factor_sign; // TODO fix this code
	avl_update(node);
	avl_update(child);
	// Return height change.
	// A balanced new root means the subtree lost one level (-1); otherwise the height is unchanged (0).
	return -!child->factor; // TODO fix this code
}
// Recursively inserts key_data (key_size bytes) with value into the subtree rooted at *branch.
// *new is set to 1 when a node is allocated. *height_change is set to 1 when the subtree at *branch
// grew taller and to 0 when it did not. It is left untouched when nothing was inserted, so callers start it at 0.
// Returns a pointer to the stored value: the new node's value, or the existing value when the key is already
// present (the existing value is not overwritten). Returns NULL if allocation fails.
static avl_type *avl_insert_key(struct avl_node *restrict *restrict branch, const unsigned char *restrict key_data, size_t key_size, avl_type value, int *restrict height_change, int *new)
{
	struct avl_node *t = *branch;

	if (!t) // vacant slot: create the node here
	{
		// The key bytes are stored inline, right after the fixed members.
		struct avl_node_mutable *created = malloc(offsetof(struct avl_node, key_data) + key_size);
		if (created == NULL)
			return NULL;
		// WARNING: Tree of complex data type should do something more here.
		created->next[0] = NULL;
		created->next[1] = NULL;
		created->key_size = key_size;
		created->value = value;
		created->factor = 0;
		memcpy(created->key_data, key_data, key_size);
		*branch = (struct avl_node *)created;
		*new = 1;
		*height_change = 1;
		return &created->value;
	}

	int diff = avl_compare(key_data, key_size, t);
	if (diff == 0)
	{
		// The key is already in the tree.
		// WARNING: Tree of complex data type should do something more here.
		return &t->value;
	}

	unsigned char side = (diff < 0);
	avl_type *slot = avl_insert_key(&t->next[side], key_data, key_size, value, height_change, new);

	if (!*height_change)
	{
		// Subtree height unchanged: only aggregates may need refreshing.
		if (*new)
			avl_update(t);
		return slot;
	}

	// The side we inserted into grew: adjust the factor and rebalance if it left [-1, 1].
	t->factor += side ? -1 : 1;
	if (t->factor > 1 || t->factor < -1)
	{
		*height_change = avl_balance(branch) + 1;
	}
	else
	{
		*height_change = (t->factor != 0);
		avl_update(t);
	}
	return slot;
}
// Inserts key_data (key_size bytes) with value. If the key already exists its value is left unchanged.
// Returns a pointer to the stored value, or NULL on allocation failure.
// avl->count grows by one only when a node was actually created.
avl_type *avl_insert(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, avl_type value)
{
	int grew = 0;
	int created = 0;
	avl_type *stored = avl_insert_key(&avl->root, key_data, key_size, value, &grew, &created);

	avl->count += (created != 0);
	return stored;
}
// Calculate the new blance factor after an item removal. Rebalance the tree if necessary. | |
static signed char avl_remove_factor(struct avl_node **branch, size_t index) | |
{ | |
struct avl_node *node = *branch; | |
node->factor -= (int [2]){1, -1}[index]; | |
if ((node->factor < -1) || (node->factor > 1)) | |
return avl_balance(branch); | |
else | |
{ | |
avl_update(node); | |
return -!node->factor; // TODO fix this code | |
} | |
} | |
static int avl_move_closest(struct avl_node **branch, const unsigned char *restrict key_data, size_t key_size, struct avl_node_mutable *restrict position) | |
{ | |
struct avl_node *t = *branch; | |
unsigned char index = (avl_compare(key_data, key_size, t) < 0); | |
struct avl_node **child = t->next + index; | |
if (*child) // the subtree is not empty | |
{ | |
// Remove the node from the subtree. Exit if the height hasn't changed. | |
if (!avl_move_closest(child, key_data, key_size, position)) | |
{ | |
avl_update(t); | |
return 0; | |
} | |
return avl_remove_factor(branch, index); | |
} | |
else // the subtree is empty but the closest node has to be moved | |
{ | |
// Copy node data to its new position. | |
// WARNING: Tree of complex data type should do something more here. | |
position->key_size = t->key_size; | |
memcpy(position->key_data, t->key_data, t->key_size); | |
position->value = t->value; | |
// Remove current node | |
if (t->next[0]) *branch = t->next[0]; | |
else *branch = t->next[1]; | |
// WARNING: Tree of complex data type should do something more here. | |
free(t); | |
return -1; | |
} | |
} | |
static signed char avl_remove_key(struct avl_node **branch, const unsigned char *restrict key_data, size_t key_size, avl_type *value_old, int *restrict changed) | |
{ | |
struct avl_node *t = *branch; | |
int diff = avl_compare(key_data, key_size, t); | |
// If this is the key to be removed | |
if (!diff) | |
{ | |
*changed = 1; | |
if (value_old) | |
*value_old = t->value; | |
if (t->next[0]) | |
{ | |
if (t->next[1]) | |
{ | |
int height_change; | |
unsigned char index = (t->factor < 0); | |
// Replace the current node with the closest by key node in the taller subtree. Exit if the height hasn't changed | |
// WARNING: Tree of complex data type may need to do something more here. | |
height_change = avl_move_closest(t->next + index, key_data, key_size, (struct avl_node_mutable *)t); | |
if (!height_change) | |
{ | |
avl_update(t); | |
return 0; | |
} | |
return avl_remove_factor(branch, index); | |
} | |
else *branch = t->next[0]; | |
} | |
else *branch = t->next[1]; | |
// WARNING: Tree of complex data type should do something more here | |
free(t); | |
return -1; | |
} | |
// Find the node in the subtree where it should be | |
{ | |
unsigned char index = (diff < 0); | |
struct avl_node **child = t->next + index; | |
// If such node doesn't exist, there is nothing to remove or balance. | |
if (!*child) // the subtree is empty | |
return 0; | |
// Remove the node from the subtree. Exit if the height hasn't changed | |
if (!avl_remove_key(child, key_data, key_size, value_old, changed)) | |
{ | |
if (*changed) | |
avl_update(t); | |
return 0; | |
} | |
return avl_remove_factor(branch, index); | |
} | |
} | |
// Removes key_data (key_size bytes) from the tree if present, decrementing avl->count when a node was removed.
// When a node is removed and value_old is not NULL, the removed value is written to *value_old.
void avl_remove(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, avl_type *value_old)
{
	int removed = 0;

	if (!avl->root)
		return; // empty tree: nothing to do

	avl_remove_key(&avl->root, key_data, key_size, value_old, &removed);
	if (removed)
		avl->count -= 1;
}
// Returns whether the iteration has been stopped. | |
static int iterate(struct avl_node *node, const unsigned char *restrict key_data, size_t key_size, int (*callback)(struct avl_node *, void *), void *argument) | |
{ | |
if (!node) | |
return 0; | |
if (key_data) | |
{ | |
int diff = avl_compare(key_data, key_size, node); | |
if (!diff) // key found | |
{ | |
// Iteration starts here. | |
return (*callback)(node, argument) || iterate(node->next[1], 0, 0, callback, argument); | |
} | |
else if (diff > 0) | |
{ | |
// Iteration will start in the left subtree. | |
return iterate(node->next[0], key_data, key_size, callback, argument) || (*callback)(node, argument) || iterate(node->next[1], 0, 0, callback, argument); | |
} | |
else | |
{ | |
// Iteration will start in the right subtree. | |
return iterate(node->next[1], key_data, key_size, callback, argument); | |
} | |
} | |
// Iteration has started. | |
return iterate(node->next[0], key_data, key_size, callback, argument) || (*callback)(node, argument) || iterate(node->next[1], key_data, key_size, callback, argument); | |
} | |
void avl_iterate(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, int (*callback)(struct avl_node *, void *), void *argument) | |
{ | |
iterate(avl->root, key_data, key_size, callback, argument); | |
} | |
static void avl_term_internal(struct avl_node *node) | |
{ | |
if (!node) | |
return; | |
avl_term_internal(node->next[0]); | |
avl_term_internal(node->next[1]); | |
// WARNING: Tree of complex data type should do something more here | |
free(node); | |
} | |
void avl_term(struct avl *avl) | |
{ | |
avl_term_internal(avl->root); | |
} | |
/* tests */ | |
#include <stdarg.h> | |
#include <stddef.h> | |
#include <setjmp.h> | |
#include <cmocka.h> | |
#define NODES_COUNT 1024 | |
static struct avl tree; | |
static avl_type values[NODES_COUNT]; | |
static size_t values_count; | |
/*void print(struct avl_node *t) | |
{ | |
if (t) | |
{ | |
if (t->next[0]) | |
{ | |
printf("["); | |
print(t->next[0]); | |
printf("]"); | |
} | |
printf(" %d ", t->key, t->factor); | |
//printf("%d (%d)", t->key, t->factor); | |
if (t->next[1]) | |
{ | |
printf("["); | |
print(t->next[1]); | |
printf("]"); | |
} | |
} | |
}*/ | |
static unsigned check_height(struct avl_node *t) | |
{ | |
if (t) | |
{ | |
unsigned a = check_height(t->next[0]); | |
unsigned b = check_height(t->next[1]); | |
assert_int_equal(t->factor, (int)a - (int)b); | |
assert_true(t->factor >= -1); | |
assert_true(t->factor <= 1); | |
return ((a >= b) ? a : b) + 1; | |
} | |
return 0; | |
} | |
static void check_sorted(struct avl_node *t) | |
{ | |
if (t->next[0]) | |
{ | |
check_sorted(t->next[0]); | |
assert_true(t->next[0]->value < t->value); | |
} | |
if (t->next[1]) | |
{ | |
assert_true(t->value < t->next[1]->value); | |
check_sorted(t->next[1]); | |
} | |
} | |
#include <netinet/in.h> | |
static void test_avl_insert(void **state) | |
{ | |
for(size_t i = 0; i < NODES_COUNT; i += 1) | |
{ | |
avl_type value = random() % 4096, *result; | |
int key = htonl(value); | |
values[values_count++] = value; | |
result = avl_insert(&tree, (const unsigned char *)&key, sizeof(key), value); | |
assert_non_null(result); | |
assert_int_equal(*result, value); | |
check_height(tree.root); | |
} | |
} | |
static void test_avl_get(void **state) | |
{ | |
for(size_t i = 0; i < NODES_COUNT; i += 1) | |
{ | |
int key = htonl(values[i]); | |
avl_type *value = avl_get(&tree, (const unsigned char *)&key, sizeof(key)); | |
assert_true(value); | |
assert_int_equal(*value, values[i]); | |
} | |
int key = htonl(4096); | |
assert_false(avl_get(&tree, (const unsigned char *)&key, sizeof(key))); | |
} | |
struct it | |
{ | |
avl_type min, max; | |
avl_type last; | |
int done; | |
}; | |
static int callback(struct avl_node *node, void *argument) | |
{ | |
struct it *it = argument; | |
assert_true(node->value > it->last); | |
assert_true(node->value >= it->min); | |
assert_false(it->done); | |
it->last = node->value; | |
it->done = (node->value > it->max); | |
return it->done; | |
} | |
static void test_avl_iterate(void **state) | |
{ | |
struct it it = {.min = 248, .max = 958, .last = -1}; | |
avl_iterate(&tree, (const unsigned char *)&it.min, sizeof(it.min), callback, &it); | |
} | |
static void test_avl_remove(void **state) | |
{ | |
int i; | |
for(i = 0; i < 2048; i += 1) | |
{ | |
int key = htonl(i); | |
avl_type value; | |
avl_remove(&tree, (const unsigned char *)&key, sizeof(key), &value); | |
check_height(tree.root); | |
// TODO some assertion about the value | |
} | |
for(i = 0; i < NODES_COUNT; i += 1) | |
{ | |
int key = htonl(values[i]); | |
if (values[i] < 2048) | |
assert_false(avl_get(&tree, (const unsigned char *)&key, sizeof(key))); | |
else | |
assert_true(avl_get(&tree, (const unsigned char *)&key, sizeof(key))); | |
} | |
} | |
static void test_avl_sorted(void **state) | |
{ | |
check_sorted(tree.root); | |
} | |
int main(int argc, char *argv[]) | |
{ | |
const struct CMUnitTest tests[] = | |
{ | |
cmocka_unit_test(test_avl_insert), | |
cmocka_unit_test(test_avl_get), | |
cmocka_unit_test(test_avl_iterate), | |
cmocka_unit_test(test_avl_remove), | |
cmocka_unit_test(test_avl_sorted), | |
}; | |
int status; | |
srandom(1); | |
tree = (struct avl){0}; | |
status = cmocka_run_group_tests(tests, 0, 0); | |
avl_term(&tree); | |
return status; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment