Skip to content

Instantly share code, notes, and snippets.

@dabrahams
Last active January 9, 2025 20:30
Show Gist options
  • Save dabrahams/7ebad8193538369a170b184b324cc21a to your computer and use it in GitHub Desktop.
Save dabrahams/7ebad8193538369a170b184b324cc21a to your computer and use it in GitHub Desktop.
/// A binary tree with value semantics.
///
/// Every value of this type is one node; the case records which children
/// the node has. `indirect` lets the cases store other `Tree` values.
indirect enum Tree {
  /// A node with no children.
  case none
  /// A node with only a left child.
  case left(Tree)
  /// A node with only a right child.
  case right(Tree)
  /// A node with a left and a right child, in that order.
  case both(Tree, Tree)

  /// The left child, or `nil` if this node has none.
  var left: Tree? {
    switch self {
    case .left(let child), .both(let child, _): return child
    case .none, .right: return nil
    }
  }

  /// The right child, or `nil` if this node has none.
  var right: Tree? {
    switch self {
    case .right(let child), .both(_, let child): return child
    case .none, .left: return nil
    }
  }

  /// Creates a node with the given optional children.
  ///
  /// - Parameters:
  ///   - left: The left child, or `nil` for none.
  ///   - right: The right child, or `nil` for none.
  init(_ left: Tree?, _ right: Tree?) {
    switch (left, right) {
    case (nil, nil): self = .none
    case (let l?, nil): self = .left(l)
    case (nil, let r?): self = .right(r)
    case (let l?, let r?): self = .both(l, r)
    }
  }

  /// Returns a tree in which every node above level zero has two children.
  ///
  /// - Parameter depth: Number of levels below the root. Zero or a negative
  ///   value yields a single childless node; previously a negative value
  ///   recursed without ever reaching the `depth == 0` base case.
  static func with(_ depth: Int) -> Tree {
    guard depth > 0 else { return Tree(nil, nil) }
    return Tree(with(depth - 1), with(depth - 1))
  }

  /// Returns the number of nodes in the tree, counting this node.
  ///
  /// All four cases are handled, so a node with a single child contributes
  /// that child's subtree instead of force-unwrapping a missing sibling
  /// (`.left`) or being counted as childless (`.right`).
  func nodeCount() -> Int {
    switch self {
    case .none:
      return 1
    case .left(let child), .right(let child):
      return 1 + child.nodeCount()
    case .both(let l, let r):
      return 1 + l.nodeCount() + r.nodeCount()
    }
  }

  /// Replaces this tree with a single childless node.
  mutating func clear() {
    self = .none
  }
}
/// Runs the benchmark and prints one check line per stage.
///
/// Stages, in order: a "stretch" tree one level deeper than the maximum, a
/// long-lived tree that is built before the loop and counted after it, and
/// batches of short-lived trees at every second depth from the minimum up.
///
/// - Parameter n: Requested maximum depth; values below `minDepth + 2` are
///   raised to that floor.
func main(_ n: Int) {
  let minDepth = 4
  let maxDepth = max(minDepth + 2, n)
  let stretchDepth = maxDepth + 1

  stretch(stretchDepth)

  var longLivedTree = Tree.with(maxDepth)

  for depth in stride(from: minDepth, to: stretchDepth, by: 2) {
    // Shallower trees are built more often: the count halves per extra level.
    let iterations = 1 << (maxDepth - depth + minDepth)
    var checksum = 0
    for _ in 1...iterations {
      checksum += count(depth)
    }
    print("\(iterations)\t trees of depth \(depth)\t check: \(checksum)")
  }

  // Named so it does not shadow the global `count(_:)` used in the loop above.
  let longLivedCheck = longLivedTree.nodeCount()
  longLivedTree.clear()
  print("long lived tree of depth \(maxDepth)\t check: \(longLivedCheck)")
}
/// Builds one tree of the given depth via `count(_:)` and prints its check value.
///
/// - Parameter depth: Depth of the tree to build.
func stretch(_ depth: Int) {
  let check = count(depth)
  print("stretch tree of depth \(depth)\t check: \(check)")
}
/// Builds a tree of the given depth, counts its nodes, then clears it.
///
/// - Parameter depth: Depth passed to `Tree.with(_:)`.
/// - Returns: The node count measured before the tree is cleared.
func count(_ depth: Int) -> Int {
  var tree = Tree.with(depth)
  // Runs after the return value has been computed, as in the original order.
  defer { tree.clear() }
  return tree.nodeCount()
}
// Program entry: the first command-line argument, when present, is the
// maximum tree depth; without one the benchmark runs at depth 10.
// A non-numeric argument stops the program with an explicit message rather
// than an anonymous force-unwrap failure.
main({ () -> Int in
  let arguments = CommandLine.arguments
  guard arguments.count > 1 else { return 10 }
  guard let depth = Int(arguments[1]) else {
    fatalError("expected an integer tree depth, got \"\(arguments[1])\"")
  }
  return depth
}())
/// An optional pointer to a `Tree` node; `nil` stands for "no node".
typealias TreePointer = UnsafeMutablePointer<Tree>?

/// A binary-tree node that refers to its children through pointers.
struct Tree {
  /// The left child, or `nil` if there is none.
  var left: TreePointer
  /// The right child, or `nil` if there is none.
  var right: TreePointer
}
/// Allocates storage for one node and initializes it with the given children.
///
/// - Parameters:
///   - left: Pointer stored as the node's left child (may be `nil`).
///   - right: Pointer stored as the node's right child (may be `nil`).
/// - Returns: A pointer to the newly initialized node. Nothing in this
///   function releases it; that is left to the caller.
func new_tree(_ left: TreePointer, _ right: TreePointer) -> TreePointer {
  let node = UnsafeMutablePointer<Tree>.allocate(capacity: 1)
  node.initialize(to: Tree(left: left, right: right))
  return node
}
/// Builds a tree in which every node above level zero has two children.
///
/// - Parameter depth: Number of levels below the root; `0` yields a single
///   childless node.
/// - Returns: A pointer to the root, allocated through `new_tree`.
func tree_with(_ depth: Int) -> TreePointer {
  if depth == 0 {
    return new_tree(nil, nil)
  }
  // Left subtree is built before the right one, matching argument order.
  let leftChild = tree_with(depth - 1)
  let rightChild = tree_with(depth - 1)
  return new_tree(leftChild, rightChild)
}
/// Returns the number of nodes reachable from `t`, counting `t` itself.
///
/// A node whose `left` pointer is `nil` is counted as 1 without looking at
/// `right`; any other node is assumed to have both children.
///
/// - Parameter t: Root of the tree to count; force-unwrapped, so it must
///   not be `nil`.
func node_count(_ t: TreePointer) -> Int {
  let node = t!.pointee
  guard node.left != nil else { return 1 }
  return 1 + node_count(node.left) + node_count(node.right)
}
/// Destroys and deallocates every descendant of `t`, leaving `t` itself
/// allocated as a childless node.
///
/// Each child is handled independently, so a node with only one child is
/// released correctly, and the child links are reset to `nil` afterwards so
/// the node never holds pointers to freed memory.
///
/// - Parameter t: Root of the tree to empty. The root allocation is NOT
///   released here; the caller still owns it. A `nil` argument is a no-op.
func clear(_ t: TreePointer) {
  guard let node = t else { return }

  if let leftChild = node.pointee.left {
    clear(leftChild)                     // release the grandchildren first
    leftChild.deinitialize(count: 1)
    leftChild.deallocate()
    node.pointee.left = nil              // drop the now-dangling link
  }

  if let rightChild = node.pointee.right {
    clear(rightChild)
    rightChild.deinitialize(count: 1)
    rightChild.deallocate()
    node.pointee.right = nil
  }
}
/// Builds a tree of the given depth, counts its nodes, and releases it.
///
/// - Parameter depth: Depth passed to `tree_with`.
/// - Returns: The node count measured before the tree is released.
func count(depth: Int) -> Int {
  let t = tree_with(depth)
  let c = node_count(t)
  clear(t)
  // `clear` releases only the descendants; the root allocation must be
  // released here or it leaks once per call.
  if let root = t {
    root.deinitialize(count: 1)
    root.deallocate()
  }
  return c
}
/// Runs the benchmark and prints one check line per stage.
///
/// Stages, in order: a "stretch" tree one level deeper than the maximum, a
/// long-lived tree that is built before the loop and counted after it, and
/// batches of short-lived trees at every second depth from the minimum up.
///
/// - Parameter n: Requested maximum depth; values below `minDepth + 2` are
///   raised to that floor.
func main(_ n: Int) {
  let minDepth = 4
  let maxDepth = max(minDepth + 2, n)
  let stretchDepth = maxDepth + 1

  stretch(stretchDepth)

  let longLivedTree = tree_with(maxDepth)

  for depth in stride(from: minDepth, to: stretchDepth, by: 2) {
    // Shallower trees are built more often: the count halves per extra level.
    let iterations = 1 << (maxDepth - depth + minDepth)
    var sum = 0
    for _ in 1...iterations {
      sum += count(depth)
    }
    print("\(iterations)\t trees of depth \(depth)\t check: \(sum)")
  }

  // Named so it does not shadow the global `count(_:)` used in the loop above.
  let longLivedCheck = node_count(longLivedTree)
  clear(longLivedTree)
  // `clear` releases only the descendants; release the root node as well so
  // the long-lived tree is fully freed.
  if let root = longLivedTree {
    root.deinitialize(count: 1)
    root.deallocate()
  }
  print("long lived tree of depth \(maxDepth)\t check: \(longLivedCheck)")
}
/// Builds one tree of the given depth via `count(_:)` and prints its check value.
///
/// - Parameter depth: Depth of the tree to build.
func stretch(_ depth: Int) {
  let check = count(depth)
  print("stretch tree of depth \(depth)\t check: \(check)")
}
/// Builds a tree of the given depth, counts its nodes, and releases it.
///
/// - Parameter depth: Depth passed to `tree_with`.
/// - Returns: The node count measured before the tree is released.
func count(_ depth: Int) -> Int {
  let t = tree_with(depth)
  let c = node_count(t)
  clear(t)
  // `clear` releases only the descendants; the root allocation must be
  // released here or it leaks once per call.
  if let root = t {
    root.deinitialize(count: 1)
    root.deallocate()
  }
  return c
}
// Program entry: the first command-line argument, when present, is the
// maximum tree depth; without one the benchmark runs at depth 10.
// A non-numeric argument stops the program with an explicit message rather
// than an anonymous force-unwrap failure.
main({ () -> Int in
  let arguments = CommandLine.arguments
  guard arguments.count > 1 else { return 10 }
  guard let depth = Int(arguments[1]) else {
    fatalError("expected an integer tree depth, got \"\(arguments[1])\"")
  }
  return depth
}())
@dabrahams
Copy link
Author

dabrahams commented Jan 9, 2025

With command-line argument 21

Original benchmark: 28.51s
After removing body of needless clear(): 23.74s
After converting to value semantics: 13.08s (ratio 0.458)
A real transliteration, using raw UnsafePointer: 10.72s (ratio 0.376)

@dabrahams
Copy link
Author

Oh, and rewriting nodeCount this way brings the value semantics version (without UnsafePointer) down to 12.61s:

   /// Returns the number of nodes in the tree, counting this node.
   ///
   /// This variant is only valid for trees made of `.none` and `.both`
   /// nodes: a `.right` node is counted as 1 without visiting its child,
   /// and any remaining case (`.left`) is routed to `unreachable()`.
   func nodeCount () -> Int {
     switch self {
     // Counted as a single node; the payload of `.right` is not visited.
     case .none, .right: 1
     case .both(let l, let r): 1 + l.nodeCount() + r.nodeCount()
     // Only `.left` gets here; it is declared impossible instead of handled.
     default: unreachable()
     }
   }

where:

/// Marks a code path as impossible by producing a value of the uninhabited
/// type `Never` from `()` with `unsafeBitCast`, without any runtime check.
///
/// NOTE(review): presumably undefined behavior if this is ever actually
/// executed — confirm; use `fatalError()` where a guaranteed trap is wanted.
func unreachable() -> Never {
    return unsafeBitCast((), to: Never.self)
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment