Last active
January 7, 2016 19:44
-
-
Save LuizZak/55f64be3df6c874b8c85 to your computer and use it in GitHub Desktop.
Simple self-balancing binary tree data structure written with enums in Swift 2.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
/// A self-contained, recursive binary-tree data structure holding ordered elements.
///
/// Conforming types expose their root value and sub-trees, support insertion and
/// removal, and can be iterated as a `SequenceType`.
protocol BinaryTreeType: SequenceType, CustomStringConvertible {
    /// The element type stored in the tree; `Comparable` so values can be ordered.
    typealias Element: Comparable

    // MARK: - Structure

    /// The value held at the root of this tree, if any.
    var value: Element? { get }
    /// The left sub-tree of this tree, if any.
    var left: Self? { get }
    /// The right sub-tree of this tree, if any.
    var right: Self? { get }

    // MARK: - Mutation

    /// Inserts `value` into this tree.
    mutating func addValue(value: Element)
    /// Deletes `value` from this tree.
    mutating func removeValue(value: Element)
    /// Performs a left rotation at the root of this tree.
    mutating func rotateLeft()
    /// Performs a right rotation at the root of this tree.
    mutating func rotateRight()

    // MARK: - Queries

    /// Returns true when `value` is stored somewhere in this tree.
    func containsValue(value: Element) -> Bool
    /// Returns the largest value stored in this tree, or nil when there is none.
    func maxValue() -> Element?
    /// Returns the smallest value stored in this tree, or nil when there is none.
    func minValue() -> Element?
    /// Returns the number of node levels from this node down to its deepest descendant.
    /// A single node with no children has depth 1.
    func depth() -> Int
    /// Returns a copy of this tree rebuilt to be as well-balanced as possible.
    func recreateBalanced() -> Self

    // MARK: - Traversal

    /// Calls `visitor` once for every element stored in this tree.
    func traverse(visitor: (Element) -> ())
    /// Calls `visitor` for every element, level by level (breadth-first).
    func traverseBreadthFirst(visitor: (Element) -> ())
}
/// A binary tree that can measure and restore the balance between its sub-trees.
protocol BalancedBinaryTreeType: BinaryTreeType {
    /// Returns the balance factor of this tree. The default implementation computes
    /// left sub-tree depth minus right sub-tree depth.
    func getBalance() -> Int
    /// Rebalances this tree in place.
    mutating func balance()
}
/// Default implementations of the query, traversal and sequence methods of `BinaryTreeType`.
extension BinaryTreeType {
    typealias GeneratorType = AnyGenerator<Element>

    /// Returns whether `value` is stored in this tree.
    ///
    /// Descends only into the sub-tree that could hold `value` (binary-search ordering).
    /// Comparisons are on `Element?`, where a nil value orders before every element.
    func containsValue(value: Element) -> Bool {
        if self.value == value {
            return true
        }
        if let left = self.left where self.value > value {
            return left.containsValue(value)
        }
        if let right = self.right where self.value < value {
            return right.containsValue(value)
        }
        return false
    }

    /// Visits the left sub-tree, then this node's value (if any), then the right sub-tree.
    func traverse(visitor: (Element) -> ()) {
        self.left?.traverse(visitor)
        if let val = self.value {
            visitor(val)
        }
        self.right?.traverse(visitor)
    }

    /// Visits the tree level by level, left to right within each level.
    func traverseBreadthFirst(visitor: (Element) -> ()) {
        var queue: [Self] = [self]
        // Advance a read index instead of calling removeFirst(), which shifts the
        // whole array on every call and made the traversal O(n^2).
        var head = 0
        while head < queue.count {
            let next = queue[head]
            head += 1
            if let value = next.value {
                visitor(value)
            }
            if let left = next.left {
                queue.append(left)
            }
            if let right = next.right {
                queue.append(right)
            }
        }
    }

    /// Returns the right-most value of the tree, i.e. its largest element.
    func maxValue() -> Element? {
        if let right = right {
            return right.maxValue()
        }
        return self.value
    }

    /// Returns the left-most value of the tree, i.e. its smallest element.
    func minValue() -> Element? {
        if let left = left {
            // Keep descending left; the smallest value is the left-most one.
            return left.minValue()
        }
        return self.value
    }

    /// Returns 1 plus the depth of the deeper sub-tree (a missing sub-tree counts as 0).
    func depth() -> Int {
        return 1 + max(self.left?.depth() ?? 0, self.right?.depth() ?? 0)
    }

    /// Returns a generator yielding the tree's values in order (left, node, right),
    /// using an explicit stack instead of recursion.
    func generate() -> AnyGenerator<Element> {
        var stack: [Self] = []
        var current: Self? = self
        return anyGenerator {
            while true {
                if let node = current {
                    // Descend left, remembering the path so we can come back up.
                    stack.append(node)
                    current = node.left
                } else if stack.isEmpty {
                    return nil
                } else {
                    let node = stack.removeLast()
                    current = node.right
                    // A value-less (empty) node yields nothing. Keep looking instead of
                    // returning nil, which would end the iteration prematurely.
                    if let value = node.value {
                        return value
                    }
                }
            }
        }
    }
}
/// Default `CustomStringConvertible` conformance: shows the root's optional value.
extension BinaryTreeType {
    var description: String {
        let rootValue = self.value
        return "BinaryTree(" + "\(rootValue)" + ")"
    }
}
/// Default balance-factor computation for balanced binary trees.
extension BalancedBinaryTreeType {
    /// Left sub-tree depth minus right sub-tree depth; a missing sub-tree counts as 0.
    /// Positive means left-heavy, negative means right-heavy.
    func getBalance() -> Int {
        let leftDepth = left.map { $0.depth() } ?? 0
        let rightDepth = right.map { $0.depth() } ?? 0
        return leftDepth - rightDepth
    }
}
/// An enum-based, self-balancing binary search tree.
///
/// `.None` is the empty tree; `.Tree` stores one value plus left and right sub-trees,
/// either of which may be `.None`. Smaller values go left and larger values go right;
/// duplicates are not stored. `addValue(_:)` and `removeValue(_:)` call `balance()` on
/// every node along the path they modify, using AVL-style single and double rotations.
enum Node<T: Comparable>: BalancedBinaryTreeType {
    typealias Element = T

    /// The empty tree.
    case None
    /// A node holding `value`, with `left` and `right` sub-trees.
    indirect case Tree(left: Node<T>, value: T, right: Node<T>)

    /// The value stored at this node, or nil for the empty tree.
    var value: T? {
        switch self {
        case .None:
            return nil
        case .Tree(_, let value, _):
            return value
        }
    }

    /// The left sub-tree (possibly `.None`), or nil when this is the empty tree.
    var left: Node<T>? {
        switch self {
        case .None:
            return nil
        case .Tree(let left, _, _):
            return left
        }
    }

    /// The right sub-tree (possibly `.None`), or nil when this is the empty tree.
    var right: Node<T>? {
        switch self {
        case .None:
            return nil
        case .Tree(_, _, let right):
            return right
        }
    }

    /// Inserts `value` (ignored if already present), then rebalances this node.
    mutating func addValue(value: T) {
        switch self {
        case .None:
            self = .Tree(left: .None, value: value, right: .None)
        case .Tree(_, let nodeValue, _) where nodeValue == value:
            // Duplicate: the tree is left unchanged.
            break
        case .Tree(var left, let nodeValue, var right):
            if value < nodeValue {
                left.addValue(value)
            } else {
                right.addValue(value)
            }
            self = .Tree(left: left, value: nodeValue, right: right)
        }
        self.balance()
    }

    /// Returns the smallest (left-most) value, or nil for the empty tree.
    func minValue() -> T? {
        switch self {
        case .None:
            return nil
        case .Tree(.None, let v, _):
            return v
        case .Tree(let left, _, _):
            return left.minValue()
        }
    }

    /// Returns the largest (right-most) value, or nil for the empty tree.
    func maxValue() -> T? {
        switch self {
        case .None:
            return nil
        case .Tree(_, let v, .None):
            return v
        case .Tree(_, _, let right):
            return right.maxValue()
        }
    }

    /// Removes `value` if present, then rebalances this node.
    mutating func removeValue(value: T) {
        switch self {
        // Value lies in the left sub-tree
        case .Tree(var left, let v, let right) where v > value:
            left.removeValue(value)
            self = .Tree(left: left, value: v, right: right)
        // Value lies in the right sub-tree
        case .Tree(let left, let v, var right) where v < value:
            right.removeValue(value)
            self = .Tree(left: left, value: v, right: right)
        // No left child (covers leaves too): lift the right sub-tree, which may be .None
        case .Tree(.None, let v, let right) where v == value:
            self = right
        // No right child: lift the left sub-tree
        case .Tree(let left, let v, .None) where v == value:
            self = left
        // Two children: replace the value with its in-order successor. The successor
        // must be captured BEFORE removing it; reading right.minValue() afterwards
        // yields the next-larger value, or nil when `right` was a single leaf.
        case .Tree(let left, let v, var right) where v == value:
            if let successor = right.minValue() {
                right.removeValue(successor)
                self = .Tree(left: left, value: successor, right: right)
            }
        default:
            // Empty tree, or value not present
            break
        }
        // Removal can make one side shallower; restore balance on the way back up,
        // mirroring addValue(_:).
        self.balance()
    }

    /// Visits every value in ascending (in-order) order.
    func traverse(visitor: (T) -> ()) {
        switch self {
        case .None:
            break
        case .Tree(let left, let val, let right):
            left.traverse(visitor)
            visitor(val)
            right.traverse(visitor)
        }
    }

    /// Returns a new tree with the same values, built by inserting the middle element
    /// of each sorted range before the elements of its two halves.
    func recreateBalanced() -> Node<T> {
        let list = self.map { $0 }
        var node = Node<T>.None
        // Pending index ranges of `list`, processed first-in first-out.
        var queue: [(start: Int, end: Int)] = [(0, list.count - 1)]
        // Read index into `queue`; avoids O(n) removeFirst() calls.
        var head = 0
        while head < queue.count {
            let (start, end) = queue[head]
            head += 1
            if start > end {
                // Empty range: skip it rather than abandoning the remaining ranges.
                continue
            }
            if end == start + 1 {
                node.addValue(list[start])
                node.addValue(list[end])
                continue
            }
            if end == start {
                node.addValue(list[start])
                continue
            }
            let mid = (start + end) / 2
            node.addValue(list[mid])
            queue.append((mid + 1, end))
            queue.append((start, mid - 1))
        }
        return node
    }

    /// Left sub-tree depth minus right sub-tree depth; 0 for the empty tree.
    func getBalance() -> Int {
        switch self {
        case .None:
            return 0
        case .Tree(let left, _, let right):
            return left.depth() - right.depth()
        }
    }

    /// Rotates this node when one side is more than one level deeper than the other.
    /// When the heavy child leans the opposite way, that child is rotated first
    /// (a double rotation).
    mutating func balance() {
        switch self {
        case .None:
            break
        case .Tree(var left, let value, var right):
            let b = getBalance()
            if b > 1 {
                if left.getBalance() < 0 {
                    left.rotateLeft()
                    self = .Tree(left: left, value: value, right: right)
                }
                self.rotateRight()
            } else if b < -1 {
                if right.getBalance() > 0 {
                    right.rotateRight()
                    self = .Tree(left: left, value: value, right: right)
                }
                self.rotateLeft()
            }
        }
    }

    /// Makes the right child the new root; no-op when there is no right child.
    mutating func rotateLeft() {
        switch self {
        case .Tree(let left, let value, .Tree(let rLeft, let rValue, let rRight)):
            let newLeft = Node<T>.Tree(left: left, value: value, right: rLeft)
            self = Node<T>.Tree(left: newLeft, value: rValue, right: rRight)
        default:
            break
        }
    }

    /// Makes the left child the new root; no-op when there is no left child.
    mutating func rotateRight() {
        switch self {
        case .Tree(.Tree(let lLeft, let lValue, let lRight), let value, let right):
            let newRight = Node<T>.Tree(left: lRight, value: value, right: right)
            self = Node<T>.Tree(left: lLeft, value: lValue, right: newRight)
        default:
            break
        }
    }

    /// Number of levels in this tree: 0 for the empty tree, 1 for a lone node.
    func depth() -> Int {
        switch self {
        case .None:
            return 0
        case .Tree(let left, _, let right):
            return 1 + max(left.depth(), right.depth())
        }
    }

    /// Returns a generator yielding values in ascending order.
    func generate() -> AnyGenerator<T> {
        var stack: [Node<T>] = []
        var current = self
        // Iterative in-order traversal: walk left pushing ancestors, then pop, emit, go right.
        return anyGenerator {
            while true {
                switch current {
                case .None:
                    if stack.isEmpty {
                        return nil
                    }
                    current = stack.removeLast()
                    switch current {
                    case .None:
                        break
                    case .Tree(_, let value, let right):
                        current = right
                        return value
                    }
                case .Tree(let left, _, _):
                    stack.append(current)
                    current = left
                }
            }
        }
    }
}
// Credits: Nate Cook (natecook1000) on Gist
/// Shuffles `list` in place (Fisher-Yates).
///
/// - parameter list: The collection to shuffle. Collections with fewer than two
///   elements are left unchanged.
func shuffle<C: MutableCollectionType where C.Index == Int>(inout list: C) {
    let c = list.count
    // `0..<(c - 1)` would trap for an empty collection, and one element needs no work.
    guard c > 1 else {
        return
    }
    // Indices are Ints but need not start at 0 (e.g. an ArraySlice), so offset them.
    let start = list.startIndex
    for i in 0..<(c - 1) {
        let j = Int(arc4random_uniform(UInt32(c - i))) + i
        // swap() must not be called with the same location twice.
        if i != j {
            swap(&list[start + i], &list[start + j])
        }
    }
}
/// Exercises the basic `BinaryTreeType` contract (add, contains, min/max, remove)
/// against a copy of `tree`.
///
/// - parameter tree: An empty tree; the checks run on a copy, so it is not modified.
func testTree<T: BinaryTreeType where T.Element == Int>(tree: T) {
    assert(tree.value == nil, "Input tree must be empty!")
    // addValue(_:) is mutating and returns Void, so copy the tree first and mutate the copy.
    var newTree = tree
    newTree.addValue(10)
    assert(newTree.containsValue(10))
    newTree.addValue(5)
    newTree.addValue(15)
    assert(newTree.minValue() == 5)
    assert(newTree.maxValue() == 15)
    newTree.removeValue(5)
    assert(!newTree.containsValue(5))
    assert(newTree.minValue() == 10)
    // Clear tree
    while let maxValue = newTree.maxValue() {
        newTree.removeValue(maxValue)
    }
    assert(newTree.value == nil, "Tree should be empty after removing every value")
    // Add, in random order, numbers from 0-30
    var numbers = [Int](0...30)
    shuffle(&numbers)
    for n in numbers {
        newTree.addValue(n)
    }
    // Assert that all numbers are contained within the new tree
    assert(numbers.reduce(true) { $0 && newTree.containsValue($1) })
    assert(newTree.minValue() == 0)
    assert(newTree.maxValue() == 30)
}
// Demo: build a small tree, print it in several orders, then run the generic checks.
print("==== Node<T> ====\n")
var root: Node<Int> = .None
for value in [3, 2, 4, 1, 5] {
    root.addValue(value)
}
print("Root depth: \(root.depth())")
print("Tree min value: \(root.minValue()) max value: \(root.maxValue())")
// In-order iteration through the SequenceType conformance (no throwaway map result).
for value in root {
    print(value)
}
print("")
root.traverse { print($0) }
print("Breadth-first traversal:")
root.traverseBreadthFirst { print($0) }
print("Root contains 3: \(root.containsValue(3))")
testTree(Node<Int>.None)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment