Skip to content

Instantly share code, notes, and snippets.

@gofer
Last active August 15, 2018 16:33
Show Gist options
  • Save gofer/d578228cb4bca009c51f to your computer and use it in GitHub Desktop.
Save gofer/d578228cb4bca009c51f to your computer and use it in GitHub Desktop.
AVL Tree
(* ORD_KEY signature for MLton: uncomment the definition below when ORD_KEY is not already provided by a library. *)
(*
signature ORD_KEY =
sig
type ord_key
val compare : (ord_key * ord_key) -> order;
end;
*)
(* Signature of a binary search tree over an ordered item type. *)
signature BIN_TREE =
sig
(* Exception exported for implementations to signal that an operation needed a node but found an empty tree. *)
exception EmptyTree;
(* Type of the elements stored in the tree. *)
type item;
(* A tree is either Empty or a Node of (item, left subtree, right subtree).
   The constructors are exposed so clients can pattern-match on the tree shape. *)
datatype tree = Empty | Node of item * tree * tree;
(* The tree with no items. *)
val empty : tree;
(* true iff the tree is Empty. *)
val isEmpty : tree -> bool;
(* member (t, x): whether t contains an item that compares EQUAL to x. *)
val member : (tree * item) -> bool;
(* Number of nodes on the longest root-to-leaf path (0 for Empty). *)
val height : tree -> int;
(* insert (t, x): a tree containing the items of t plus x. *)
val insert : (tree * item) -> tree;
(* delete (t, x): a tree containing the items of t except one comparing EQUAL to x. *)
val delete : (tree * item) -> tree;
end;
(* AVL tree implementation.
   A self-balancing binary search tree: after every insert and delete, the
   heights of the two subtrees of every node differ by at most one.

   Note: heights are not cached in the nodes (the tree datatype is fixed by
   BIN_TREE), so `height` walks the whole subtree and each rebalancing step
   costs time linear in the size of the subtree being rebalanced. *)
functor AVLTreeFunctor
  (
    OrderedKey : ORD_KEY
  ) :> BIN_TREE
         where type item = OrderedKey.ord_key
  =
struct
  exception EmptyTree;

  type item = OrderedKey.ord_key;

  datatype tree = Empty | Node of item * tree * tree;

  val empty = Empty;

  (* true iff the tree has no nodes. *)
  fun isEmpty Empty = true
    | isEmpty _ = false;

  (* Binary search for an item comparing EQUAL to `query`. *)
  fun member (Empty, _) = false
    | member (Node(value, left, right), query) =
        case OrderedKey.compare(query, value)
          of EQUAL => true
           | LESS => member (left, query)
           | GREATER => member (right, query);

  (* Length of the longest root-to-leaf path; 0 for Empty.
     Recomputed from scratch by traversing the whole subtree. *)
  fun height Empty = 0
    | height (Node(_, left, right)) = 1 + Int.max (height left, height right);

  (* Balance factor: height of the left subtree minus height of the right one. *)
  fun bias Empty = 0
    | bias (Node(_, left, right)) = height left - height right;

  (* Single left rotation; a node without a right child is returned unchanged. *)
  fun rotateLeft (Node(value, left, Node(rightValue, rightLeft, rightRight))) =
        Node(rightValue, Node(value, left, rightLeft), rightRight)
    | rotateLeft node = node;

  (* Single right rotation; a node without a left child is returned unchanged. *)
  fun rotateRight (Node(value, Node(leftValue, leftLeft, leftRight), right)) =
        Node(leftValue, leftLeft, Node(value, leftRight, right))
    | rotateRight node = node;

  (* Restore the AVL invariant at the root of a node whose subtrees are AVL
     trees with heights differing by at most two.
     A heavy child with bias 0 (which can only arise after a deletion) must be
     fixed with a single rotation: a double rotation there would leave an
     inner subtree with bias +-2. *)
  fun balance Empty = Empty
    | balance (node as Node(value, left, right)) =
        let
          val b = bias node
        in
          if b > 1 then
            if bias left >= 0
            then rotateRight node                                   (* left-left *)
            else rotateRight (Node(value, rotateLeft left, right))  (* left-right *)
          else if b < ~1 then
            if bias right <= 0
            then rotateLeft node                                    (* right-right *)
            else rotateLeft (Node(value, left, rotateRight right))  (* right-left *)
          else node
        end;

  (* Insert `query`, replacing an existing item that compares EQUAL to it,
     and rebalance every node on the path back to the root. *)
  fun insert (Empty, query) = Node(query, Empty, Empty)
    | insert (Node(value, left, right), query) =
        balance
          (case OrderedKey.compare(query, value)
             of EQUAL => Node(query, left, right)
              | LESS => Node(value, insert (left, query), right)
              | GREATER => Node(value, left, insert (right, query)));

  (* Split off the smallest item of a non-empty tree. Returns the item and the
     remaining tree, rebalanced at every node along the leftmost path.
     Raises EmptyTree on Empty; `delete` only calls it on a non-empty subtree. *)
  fun removeMin Empty = raise EmptyTree
    | removeMin (Node(value, Empty, right)) = (value, right)
    | removeMin (Node(value, left, right)) =
        let
          val (minimum, left') = removeMin left
        in
          (minimum, balance (Node(value, left', right)))
        end;

  (* Remove an item comparing EQUAL to `query` (no-op if absent), rebalancing
     every node on the search path. A node with two children is replaced by
     its in-order successor, the minimum of its right subtree. Using the
     predecessor (the maximum of the left subtree) would work symmetrically. *)
  fun delete (Empty, _) = Empty
    | delete (Node(value, left, right), query) =
        case OrderedKey.compare(query, value)
          of LESS => balance (Node(value, delete (left, query), right))
           | GREATER => balance (Node(value, left, delete (right, query)))
           | EQUAL =>
               (case (left, right)
                  of (Empty, _) => right
                   | (_, Empty) => left
                   | _ =>
                       let
                         val (successor, right') = removeMin right
                       in
                         balance (Node(successor, left, right'))
                       end);
end;
(* Integers with a total order, for use as AVL tree keys. *)
structure OrderedInt :> ORD_KEY
  where type ord_key = Int.int
  =
struct
  type ord_key = Int.int;

  (* The Basis library comparison yields LESS / EQUAL / GREATER, exactly as a
     hand-written if-chain on = and < would. *)
  val compare = Int.compare;
end;
(* AVL tree specialised to integer items, ordered by OrderedInt. *)
structure IntAVLTree = AVLTreeFunctor (OrderedInt);
(* Debugging utilities *)

(* Render an integer AVL tree as nested tuples "(value, left, right)",
   writing "_" for each empty subtree. *)
fun to_string IntAVLTree.Empty = "_"
  | to_string (IntAVLTree.Node(value, left, right)) =
      String.concat
        ["(", Int.toString value,
         ", ", to_string left,
         ", ", to_string right,
         ")"];
(* Render an integer AVL tree as a Graphviz DOT digraph named "test".
   Nodes are visited in preorder; for each node, an edge to its left child and
   then to its right child is emitted, omitting edges to empty subtrees. *)
fun tree_to_dot tree =
  let
    (* DOT label of a subtree's root, or NONE for an empty subtree. *)
    fun label IntAVLTree.Empty = NONE
      | label (IntAVLTree.Node(v, _, _)) = SOME (Int.toString v)

    (* One edge line from `parent` to the root of `child`, or "" if child is empty. *)
    fun edge parent child =
      (case label child
         of NONE => ""
          | SOME c => " " ^ parent ^ " -> " ^ c ^ ";\n")

    (* Prepend this subtree's edge lines (in preorder) to `acc`. *)
    fun walk IntAVLTree.Empty acc = acc
      | walk (IntAVLTree.Node(v, l, r)) acc =
          let
            val p = Int.toString v
          in
            edge p l :: edge p r :: walk l (walk r acc)
          end
  in
    String.concat ("digraph test {\n" :: walk tree ["}"])
  end;
(* Experiment: insert 1..15, delete 4 and 10, then print the tree in Graphviz DOT format *)
(*
val tree = IntAVLTree.empty;
val tree = IntAVLTree.insert (tree, 1);
val tree = IntAVLTree.insert (tree, 2);
val tree = IntAVLTree.insert (tree, 3);
val tree = IntAVLTree.insert (tree, 4);
val tree = IntAVLTree.insert (tree, 5);
val tree = IntAVLTree.insert (tree, 6);
val tree = IntAVLTree.insert (tree, 7);
val tree = IntAVLTree.insert (tree, 8);
val tree = IntAVLTree.insert (tree, 9);
val tree = IntAVLTree.insert (tree, 10);
val tree = IntAVLTree.insert (tree, 11);
val tree = IntAVLTree.insert (tree, 12);
val tree = IntAVLTree.insert (tree, 13);
val tree = IntAVLTree.insert (tree, 14);
val tree = IntAVLTree.insert (tree, 15);
val tree = IntAVLTree.delete (tree, 4);
val tree = IntAVLTree.delete (tree, 10);
val () = print ((tree_to_dot tree) ^ "\n");
*)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment