Skip to content

Instantly share code, notes, and snippets.

@LuizZak
Last active January 7, 2016 19:44
Show Gist options
  • Save LuizZak/55f64be3df6c874b8c85 to your computer and use it in GitHub Desktop.
Save LuizZak/55f64be3df6c874b8c85 to your computer and use it in GitHub Desktop.
Simple self-balancing binary tree data structure written with enums in Swift 2.0
import Foundation
/// A protocol to be implemented by self-contained, recursive binary-tree data structures.
///
/// Conformers can be iterated as a sequence of their elements (`SequenceType`)
/// and converted to text (`CustomStringConvertible`).
protocol BinaryTreeType: SequenceType, CustomStringConvertible {
/// The type of element stored in this binary tree. This type must conform to the Comparable
/// protocol, because values are compared to decide whether to descend left or right.
typealias Element: Comparable
/// Gets the value stored at the root of this binary tree, or `nil` when the tree is empty
var value: Element? { get }
/// Gets the left sub-tree for this recursive binary tree. Conformers may represent a
/// missing child either as `nil` or as an empty tree value.
var left: Self? { get }
/// Gets the right sub-tree for this recursive binary tree. Conformers may represent a
/// missing child either as `nil` or as an empty tree value.
var right: Self? { get }
/// Adds a value to this binary tree
mutating func addValue(value: Element)
/// Removes a value from this binary tree
mutating func removeValue(value: Element)
/// Returns whether a given element is contained within this binary tree
func containsValue(value: Element) -> Bool
/// Traverses this binary tree, calling `visitor` with each element
/// (left sub-tree, then this node, then right sub-tree)
func traverse(visitor: (Element) -> ())
/// Traverses this binary tree level by level (breadth-first), calling `visitor` with each element
func traverseBreadthFirst(visitor: (Element) -> ())
/// Returns a recreation of this binary tree, in the most well-balanced shape possible
func recreateBalanced() -> Self
/// Rotates this binary tree to the left, so that its right child becomes the new root
mutating func rotateLeft()
/// Rotates this binary tree to the right, so that its left child becomes the new root
mutating func rotateRight()
/// Returns the maximum value contained in this binary tree, or `nil` if it holds no values
func maxValue() -> Element?
/// Returns the minimum value contained in this binary tree, or `nil` if it holds no values
func minValue() -> Element?
/// Returns the depth of this binary tree: the number of nodes on the longest path from this
/// node down to a leaf. A node without children has depth 1.
func depth() -> Int
}
/// Extended binary tree which contains sub-tree balancing functionality
protocol BalancedBinaryTreeType: BinaryTreeType {
/// Rebalances this binary tree (typically by applying rotations) when its sub-trees'
/// depths drift too far apart
mutating func balance()
/// Gets the current balance index for this binary tree: the depth of the left sub-tree
/// minus the depth of the right sub-tree. Positive values mean the tree is left-heavy,
/// negative values mean it is right-heavy, and 0 means both sides have equal depth.
func getBalance() -> Int
}
/// Default implementations of the search, traversal and iteration requirements of
/// `BinaryTreeType`, written purely in terms of the `value`, `left` and `right` accessors.
extension BinaryTreeType {
    typealias GeneratorType = AnyGenerator<Element>

    /// Binary-searches this tree for `value`, descending into the left sub-tree when
    /// `value` is smaller than this node's value and into the right one when it is larger.
    func containsValue(value: Element) -> Bool {
        if self.value == value {
            return true
        }
        if let left = self.left where self.value > value {
            return left.containsValue(value)
        }
        if let right = self.right where self.value < value {
            return right.containsValue(value)
        }
        return false
    }

    /// Visits every element in order: left sub-tree, this node, right sub-tree.
    func traverse(visitor: (Element) -> ()) {
        self.left?.traverse(visitor)
        if let val = self.value {
            visitor(val)
        }
        self.right?.traverse(visitor)
    }

    /// Visits every element level by level, using a FIFO queue of pending sub-trees.
    func traverseBreadthFirst(visitor: (Element) -> ()) {
        var queue: [Self] = [self]
        while !queue.isEmpty {
            let next = queue.removeFirst()
            if let value = next.value {
                visitor(value)
            }
            if let left = next.left {
                queue.append(left)
            }
            if let right = next.right {
                queue.append(right)
            }
        }
    }

    /// Returns the value of the right-most node.
    func maxValue() -> Element? {
        if let right = right {
            return right.maxValue()
        }
        return self.value
    }

    /// Returns the value of the left-most node.
    func minValue() -> Element? {
        if let left = left {
            // The minimum lives at the bottom of the *left* spine; this previously
            // (incorrectly) asked the left sub-tree for its maximum.
            return left.minValue()
        }
        return self.value
    }

    /// Returns 1 plus the depth of the deeper sub-tree (a missing sub-tree counts as 0).
    func depth() -> Int {
        return 1 + max(self.left?.depth() ?? 0, self.right?.depth() ?? 0)
    }

    /// Returns a generator that yields the elements in order, using an explicit stack
    /// instead of recursion.
    func generate() -> AnyGenerator<Element> {
        var stack: [Self] = []
        var current: Self? = self
        // Iterative in-order traversal: walk down the left spine pushing nodes, then
        // pop a node, yield its value and continue with its right sub-tree.
        return anyGenerator {
            while true {
                if let node = current {
                    stack.append(node)
                    current = node.left
                } else if stack.isEmpty {
                    return nil
                } else {
                    let node = stack.removeLast()
                    current = node.right
                    return node.value
                }
            }
        }
    }
}
/// Default `CustomStringConvertible` conformance shared by all binary trees.
extension BinaryTreeType {
    /// A short textual form showing only the root value, interpolated as the
    /// optional `value` property.
    var description: String {
        let rootValue = self.value
        return "BinaryTree(\(rootValue))"
    }
}
/// Default implementation for some methods for the balanced binary tree type
extension BalancedBinaryTreeType {
    /// Returns the depth of the left sub-tree minus the depth of the right sub-tree;
    /// a missing sub-tree contributes a depth of 0.
    func getBalance() -> Int {
        let leftDepth = left.map { $0.depth() } ?? 0
        let rightDepth = right.map { $0.depth() } ?? 0
        return leftDepth - rightDepth
    }
}
/// An enum-based, self-balancing binary search tree.
///
/// Values smaller than a node's value are stored in its `left` sub-tree and larger
/// values in its `right` sub-tree; duplicate insertions are ignored. Both `addValue`
/// and `removeValue` finish by calling `balance()`, which rotates a node whenever the
/// depths of its two sub-trees differ by more than one.
enum Node<T: Comparable>: BalancedBinaryTreeType {
    typealias Element = T

    /// The empty tree.
    case None
    /// A node holding `value` together with its two (possibly empty) sub-trees.
    indirect case Tree(left: Node<T>, value: T, right: Node<T>)

    /// The value stored at this node, or `nil` for the empty tree.
    var value: T? {
        switch self {
        case .None:
            return nil
        case .Tree(_, let value, _):
            return value
        }
    }

    /// The left sub-tree (which may itself be `.None`), or `nil` for the empty tree.
    var left: Node<T>? {
        switch self {
        case .None:
            return nil
        case .Tree(let left, _, _):
            return left
        }
    }

    /// The right sub-tree (which may itself be `.None`), or `nil` for the empty tree.
    var right: Node<T>? {
        switch self {
        case .None:
            return nil
        case .Tree(_, _, let right):
            return right
        }
    }

    /// Inserts `value` (ignoring it if already present) and rebalances this node.
    mutating func addValue(value: T) {
        switch self {
        case .None:
            self = .Tree(left: .None, value: value, right: .None)
        case .Tree(_, let nodeValue, _) where nodeValue == value:
            // Duplicates are not stored.
            break
        case .Tree(var left, let nodeValue, var right):
            if value < nodeValue {
                left.addValue(value)
            } else {
                right.addValue(value)
            }
            self = .Tree(left: left, value: nodeValue, right: right)
        }
        self.balance()
    }

    /// Returns the value of the left-most node, or `nil` for the empty tree.
    func minValue() -> T? {
        switch self {
        case .None:
            return nil
        case .Tree(.None, let v, _):
            return v
        case .Tree(let left, _, _):
            return left.minValue()
        }
    }

    /// Returns the value of the right-most node, or `nil` for the empty tree.
    func maxValue() -> T? {
        switch self {
        case .None:
            return nil
        case .Tree(_, let v, .None):
            return v
        case .Tree(_, _, let right):
            return right.maxValue()
        }
    }

    /// Removes `value` from the tree if it is present, then rebalances this node.
    mutating func removeValue(value: T) {
        switch self {
        // Value lies in left sub-tree
        case .Tree(var left, let v, let right) where v > value:
            left.removeValue(value)
            self = .Tree(left: left, value: v, right: right)
        // Value is a leaf - replace it with the empty tree
        case .Tree(.None, let v, .None) where v == value:
            self = .None
        // Value has a single child - lift that child into this node's place
        case .Tree(.None, let v, let right) where v == value:
            self = right
        case .Tree(let left, let v, .None) where v == value:
            self = left
        // Value has two children - replace it with its in-order successor (the
        // smallest value of the right sub-tree) and delete that successor from the
        // right sub-tree. The successor must be read *before* the deletion: reading
        // it afterwards yields the next-larger value, or nil once `right` is empty.
        case .Tree(let left, let v, var right) where v == value:
            if let successor = right.minValue() {
                right.removeValue(successor)
                self = .Tree(left: left, value: successor, right: right)
            }
        // Value lies in right sub-tree
        case .Tree(let left, let v, var right) where v < value:
            right.removeValue(value)
            self = .Tree(left: left, value: v, right: right)
        default:
            // Empty tree: the value is not present.
            break
        }
        // Removal can unbalance this node just like insertion can.
        self.balance()
    }

    /// Visits every element in order: left sub-tree, this node, right sub-tree.
    func traverse(visitor: (T) -> ()) {
        switch self {
        case .None:
            break
        case .Tree(let left, let val, let right):
            left.traverse(visitor)
            visitor(val)
            right.traverse(visitor)
        }
    }

    /// Returns a new tree with the same values, built by inserting the midpoint of each
    /// sorted index range before the halves on either side of it.
    func recreateBalanced() -> Node<T> {
        // In-order iteration yields the values in ascending order.
        let list = self.map { $0 }
        var node = Node<T>.None
        var queue: [(start: Int, end: Int)] = []
        queue.append((0, list.count - 1))
        while !queue.isEmpty {
            let next = queue.removeFirst()
            let start = next.start
            let end = next.end
            if start > end {
                // Empty range (e.g. an empty source tree): skip it instead of
                // abandoning the ranges still waiting in the queue.
                continue
            }
            if end == start + 1 {
                node.addValue(list[start])
                node.addValue(list[end])
                continue
            }
            if end == start {
                node.addValue(list[start])
                continue
            }
            let mid = (start + end) / 2
            node.addValue(list[mid])
            queue.append((mid + 1, end))
            queue.append((start, mid - 1))
        }
        return node
    }

    /// Returns left depth minus right depth (0 for the empty tree). Depths are
    /// recomputed recursively on every call.
    func getBalance() -> Int {
        switch self {
        case .None:
            return 0
        case .Tree(let left, _, let right):
            return left.depth() - right.depth()
        }
    }

    /// Rotates this node when its balance index is outside [-1, 1], using a double
    /// rotation when the heavier child leans the opposite way.
    mutating func balance() {
        switch self {
        case .None:
            break
        case .Tree(var left, let value, var right):
            let b = getBalance()
            if b > 1 {
                // Left-heavy; if the left child leans right, rotate it left first.
                if left.getBalance() < 0 {
                    left.rotateLeft()
                    self = .Tree(left: left, value: value, right: right)
                }
                self.rotateRight()
            } else if b < -1 {
                // Right-heavy; if the right child leans left, rotate it right first.
                if right.getBalance() > 0 {
                    right.rotateRight()
                    self = .Tree(left: left, value: value, right: right)
                }
                self.rotateLeft()
            }
        }
    }

    /// Makes the right child the new root; does nothing if there is no right child.
    mutating func rotateLeft() {
        switch self {
        case .Tree(let left, let value, .Tree(let rLeft, let rValue, let rRight)):
            let newLeft = Node<T>.Tree(left: left, value: value, right: rLeft)
            self = Node<T>.Tree(left: newLeft, value: rValue, right: rRight)
        default:
            break
        }
    }

    /// Makes the left child the new root; does nothing if there is no left child.
    mutating func rotateRight() {
        switch self {
        case .Tree(.Tree(let lLeft, let lValue, let lRight), let value, let right):
            let newRight = Node<T>.Tree(left: lRight, value: value, right: right)
            self = Node<T>.Tree(left: lLeft, value: lValue, right: newRight)
        default:
            break
        }
    }

    /// Returns the number of nodes on the longest root-to-leaf path (0 when empty).
    func depth() -> Int {
        switch self {
        case .None:
            return 0
        case .Tree(let left, _, let right):
            return 1 + max(left.depth(), right.depth())
        }
    }

    /// Returns a generator yielding the values in ascending (in-order) order.
    func generate() -> AnyGenerator<T> {
        var stack: [Node<T>] = []
        var current = self
        // Iterative in-order traversal
        return anyGenerator {
            while true {
                switch current {
                case .Tree(let left, _, _):
                    // Remember this node and keep descending its left spine.
                    stack.append(current)
                    current = left
                case .None:
                    if stack.isEmpty {
                        return nil
                    }
                    // Only `.Tree` nodes are ever pushed, so this always matches.
                    if case .Tree(_, let value, let right) = stack.removeLast() {
                        current = right
                        return value
                    }
                }
            }
        }
    }
}
// Credits: Nate Cook (natecook1000) on Gist
/// Shuffles `list` in place using the Fisher-Yates algorithm.
///
/// Works for any integer-indexed mutable collection, including slices whose
/// `startIndex` is not zero. Collections with fewer than two elements are left untouched.
///
/// - Parameter list: The collection to shuffle in place.
func shuffle<C: MutableCollectionType where C.Index == Int>(inout list: C) {
    let c = list.count
    // `0..<(c - 1)` would trap for an empty collection, and a single element
    // needs no shuffling.
    guard c > 1 else { return }
    // Offsets below are relative to the first index, so slices work too.
    let base = list.startIndex
    for i in 0..<(c - 1) {
        let j = Int(arc4random_uniform(UInt32(c - i))) + i
        if i != j {
            // swap(_:_:) does not support swapping a location with itself.
            swap(&list[base + i], &list[base + j])
        }
    }
}
/// Smoke-tests an implementation of `BinaryTreeType` holding `Int` values.
///
/// - Parameter tree: An empty tree; all checks run against a local mutable copy.
func testTree<T: BinaryTreeType where T.Element == Int>(tree: T) {
    assert(tree.value == nil, "Input tree must be empty!")
    // `addValue` is mutating and returns Void, so mutate a copy of the tree
    // instead of binding the call's result.
    var newTree = tree
    newTree.addValue(10)
    assert(newTree.containsValue(10))
    newTree.addValue(5)
    newTree.addValue(15)
    assert(newTree.minValue() == 5)
    assert(newTree.maxValue() == 15)
    newTree.removeValue(5)
    assert(!newTree.containsValue(5))
    assert(newTree.minValue() == 10)
    // In a binary search tree built from 10, 15 and 5, the value 10 has two
    // children; removing it exercises the successor-replacement path.
    newTree.addValue(5)
    newTree.removeValue(10)
    assert(!newTree.containsValue(10))
    assert(newTree.containsValue(5) && newTree.containsValue(15))
    assert(newTree.minValue() == 5)
    assert(newTree.maxValue() == 15)
    // Clear tree
    while let maxValue = newTree.maxValue() {
        newTree.removeValue(maxValue)
    }
    assert(newTree.value == nil, "Tree must be empty after removing every value")
    // Add, in random order, numbers from 0-30
    var numbers = [Int](0...30)
    shuffle(&numbers)
    for n in numbers {
        newTree.addValue(n)
    }
    // Assert that all numbers are contained within the new tree
    assert(!numbers.contains { !newTree.containsValue($0) })
}
print("==== Node<T> ====\n")
var root: Node<Int> = .None
for value in [3, 2, 4, 1, 5] {
    root.addValue(value)
}
print("Root depth: \(root.depth())")
print("Tree min value: \(root.minValue()) max value: \(root.maxValue())")
// In-order iteration through the SequenceType conformance; a plain loop is used
// instead of `map`, which would build and discard an array of Void results.
for value in root {
    print(value)
}
print("")
// In-order traversal through the visitor-based API.
root.traverse { print($0) }
print("Breadth-first traversal:")
root.traverseBreadthFirst { print($0) }
print("Root contains 3: \(root.containsValue(3))")
testTree(Node<Int>.None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment