/* * avltree.h * * Copyright (C) 2018 Aleksandar Andrejevic * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ #ifndef __MONOLITHIUM_AVLTREE_H__ #define __MONOLITHIUM_AVLTREE_H__ #include "defs.h" #define AVL_TREE_INIT(t, s, n, k, c) avl_tree_init((t), (ptrdiff_t)&((s*)NULL)->k - (ptrdiff_t)&((s*)NULL)->n, sizeof(((s*)NULL)->k), c) typedef int (*avl_compare_proc_t)(const void *key1, const void *key2); typedef struct avl_node { struct avl_node *parent; struct avl_node *left; struct avl_node *right; struct avl_node *next_equal; struct avl_node *prev_equal; int balance; } avl_node_t; typedef struct avl_tree { avl_node_t *root; ptrdiff_t key_offset; size_t key_size; avl_compare_proc_t compare; } avl_tree_t; static inline void avl_tree_init(avl_tree_t *tree, ptrdiff_t key_offset, size_t key_size, avl_compare_proc_t compare) { tree->root = NULL; tree->key_offset = key_offset; tree->key_size = key_size; tree->compare = compare; } static inline void *avl_get_keyptr(const avl_tree_t *tree, const avl_node_t *node) { return (void*)((ptrdiff_t)node + tree->key_offset); } static inline avl_node_t *avl_tree_lookup(const avl_tree_t *tree, const void *key) { avl_node_t *node = tree->root; while (node) { const void *node_key = avl_get_keyptr(tree, node); int comparison = tree->compare(key, node_key); if (comparison == 0) return node; else if (comparison < 0) node = node->left; else node = node->right; } return NULL; } static inline avl_node_t *avl_tree_lower_bound(const avl_tree_t *tree, const void *key) { avl_node_t *node = tree->root; while (node && tree->compare(avl_get_keyptr(tree, node), key) > 0) node = node->left; if (!node) return NULL; while (node->right && tree->compare(avl_get_keyptr(tree, node->right), key) <= 0) node = node->right; return node; } static inline avl_node_t *avl_tree_upper_bound(const avl_tree_t *tree, const void *key) { avl_node_t *node = tree->root; while (node && tree->compare(avl_get_keyptr(tree, node), key) < 0) node = node->right; if (!node) return NULL; while (node->left && tree->compare(avl_get_keyptr(tree, node->left), key) >= 0) node = node->left; return node; } static inline avl_node_t *avl_get_next_node(const avl_node_t *node) { while (node->prev_equal) node = node->prev_equal; if (node->right) { node = node->right; while (node->left) node = node->left; } else { while (node->parent && node->parent->right == node) node = node->parent; node = node->parent; } return (avl_node_t*)node; } static inline avl_node_t *avl_get_previous_node(const avl_node_t *node) { while (node->prev_equal) node = node->prev_equal; if (node->left) { node = node->left; while (node->right) node = node->right; } else { while (node->parent && node->parent->left == node) node = node->parent; node = node->parent; } return (avl_node_t*)node; } static inline avl_node_t *avl_rotate_left(avl_tree_t *tree, avl_node_t *root) { avl_node_t *pivot = root->right; root->right = pivot->left; if (root->right) root->right->parent = root; pivot->parent = root->parent; pivot->left = root; root->parent = pivot; if (pivot->parent) { if (pivot->parent->left == root) pivot->parent->left = pivot; else if (pivot->parent->right == root) pivot->parent->right = pivot; } else { tree->root = pivot; } root->balance -= pivot->balance > 0 ? pivot->balance + 1 : 1; pivot->balance += root->balance < 0 ? root->balance - 1 : -1; return pivot; } static inline avl_node_t *avl_rotate_right(avl_tree_t *tree, avl_node_t *root) { avl_node_t *pivot = root->left; root->left = pivot->right; if (root->left) root->left->parent = root; pivot->parent = root->parent; pivot->right = root; root->parent = pivot; if (pivot->parent) { if (pivot->parent->left == root) pivot->parent->left = pivot; else if (pivot->parent->right == root) pivot->parent->right = pivot; } else { tree->root = pivot; } root->balance -= pivot->balance < 0 ? pivot->balance - 1 : -1; pivot->balance += root->balance > 0 ? root->balance + 1 : 1; return pivot; } static void avl_tree_insert(avl_tree_t *tree, avl_node_t *node) { node->left = node->right = node->parent = node->next_equal = node->prev_equal = NULL; node->balance = 0; if (!tree->root) { tree->root = node; return; } avl_node_t *current = tree->root; const void *node_key = avl_get_keyptr(tree, node); while (TRUE) { const void *key = avl_get_keyptr(tree, current); int comparison = tree->compare(node_key, key); if (comparison == 0) { while (current->next_equal) current = current->next_equal; current->next_equal = node; node->prev_equal = current; return; } else if (comparison < 0) { if (!current->left) { node->parent = current; current->left = node; break; } else { current = current->left; } } else { if (!current->right) { node->parent = current; current->right = node; break; } else { current = current->right; } } } while (current) { if (node == current->left) current->balance--; else current->balance++; if (current->balance == 0) break; if (current->balance < -1) { if (node->balance > 0) avl_rotate_left(tree, current->left); current = avl_rotate_right(tree, current); break; } else if (current->balance > 1) { if (node->balance < 0) avl_rotate_right(tree, current->right); current = avl_rotate_left(tree, current); break; } node = current; current = current->parent; } } static void avl_tree_remove(avl_tree_t *tree, avl_node_t *node) { if (node->prev_equal) { node->prev_equal->next_equal = node->next_equal; if (node->next_equal) node->next_equal->prev_equal = node->prev_equal; node->next_equal = node->prev_equal = NULL; return; } else if (node->next_equal) { node->next_equal->parent = node->parent; node->next_equal->left = node->left; node->next_equal->right = node->right; node->next_equal->prev_equal = NULL; if (node->parent) { if (node->parent->left == node) node->parent->left = node->next_equal; else node->parent->right = node->next_equal; } else { tree->root = node->next_equal; } if (node->left) node->left->parent = node->next_equal; if (node->right) node->right->parent = node->next_equal; node->parent = node->left = node->right = node->next_equal = NULL; node->balance = 0; return; } if (node->left && node->right) { avl_node_t *replacement = node->right; if (replacement->left) { while (replacement->left) replacement = replacement->left; avl_node_t *temp_parent = replacement->parent; avl_node_t *temp_right = replacement->right; int temp_balance = replacement->balance; replacement->parent = node->parent; replacement->left = node->left; replacement->right = node->right; replacement->balance = node->balance; if (replacement->parent) { if (replacement->parent->left == node) replacement->parent->left = replacement; else replacement->parent->right = replacement; } else { tree->root = replacement; } if (replacement->left) replacement->left->parent = replacement; if (replacement->right) replacement->right->parent = replacement; node->parent = temp_parent; node->left = NULL; node->right = temp_right; node->balance = temp_balance; if (node->parent->left == replacement) node->parent->left = node; else node->parent->right = node; if (node->right) node->right->parent = node; } else { avl_node_t *temp_right = replacement->right; int temp_balance = replacement->balance; replacement->parent = node->parent; replacement->left = node->left; replacement->right = node; replacement->balance = node->balance; if (replacement->parent) { if (replacement->parent->left == node) replacement->parent->left = replacement; else replacement->parent->right = replacement; } else { tree->root = replacement; } if (replacement->left) replacement->left->parent = replacement; node->parent = replacement; node->left = NULL; node->right = temp_right; node->balance = temp_balance; if (node->right) node->right->parent = node; } } avl_node_t *current = node->parent; bool_t left_child; if (current) { left_child = current->left == node; if (left_child) { current->left = node->left ? node->left : node->right; if (current->left) current->left->parent = current; } else { current->right = node->left ? node->left : node->right; if (current->right) current->right->parent = current; } } else { tree->root = node->left ? node->left : node->right; if (tree->root) tree->root->parent = NULL; } node->parent = node->left = node->right = NULL; node->balance = 0; while (current) { if (left_child) current->balance++; else current->balance--; if (current->balance == 1 || current->balance == -1) break; if (current->balance < -1) { int balance = current->left->balance; if (balance > 0) avl_rotate_left(tree, current->left); current = avl_rotate_right(tree, current); if (balance == 0) break; } else if (current->balance > 1) { int balance = current->right->balance; if (balance < 0) avl_rotate_right(tree, current->right); current = avl_rotate_left(tree, current); if (balance == 0) break; } node = current; current = current->parent; if (current) left_child = current->left == node; } } static inline void avl_tree_change_key(avl_tree_t *tree, avl_node_t *node, const void *new_key) { avl_tree_remove(tree, node); __builtin_memcpy(avl_get_keyptr(tree, node), new_key, tree->key_size); avl_tree_insert(tree, node); } #endif