// GitHub gist marcrasi/989b2014416bc2948ed1c03b1352e5c8, created June 4, 2019 17:39.
// GitHub warning: this file may contain hidden or bidirectional Unicode text that could be
// interpreted or compiled differently than it appears; review it in an editor that reveals
// hidden Unicode characters.
// MARK: - Some differentiable array manipulation functions used in the algorithms.

/// Non-mutating array helpers, each registered with a hand-written VJP via
/// `@differentiable(vjp:)`. Every VJP returns the original result together with a
/// pullback that maps an output tangent back to the input tangent(s).
extension Array where Element: Differentiable {
    /// Returns a copy of the array with the elements at `i` and `j` exchanged.
    @differentiable(vjp: _vjpSwappedAt)
    func swappedAt(_ i: Int, _ j: Int) -> Array {
        var tmp = self
        tmp.swapAt(i, j)
        return tmp
    }

    /// VJP for `swappedAt(_:_:)`. A swap is its own inverse, so the pullback swaps
    /// the same two positions of the incoming tangent.
    func _vjpSwappedAt(_ i: Int, _ j: Int) -> (Array, (TangentVector) -> TangentVector) {
        return (swappedAt(i, j), { TangentVector($0.base.swappedAt(i, j)) })
    }

    /// Returns the array without its first element; an empty array stays empty.
    @differentiable(vjp: _vjpDroppedFirst)
    func droppedFirst() -> Array {
        return Array(self.dropFirst())
    }

    /// VJP for `droppedFirst()`.
    ///
    /// The dropped element contributes nothing to the output, so its gradient is zero
    /// and the pullback prepends a zero tangent. When `self` is empty nothing was
    /// dropped, so the incoming tangent is passed through unchanged. This keeps the
    /// returned tangent's count equal to `self.count`; previously a spurious zero
    /// was prepended in that case.
    func _vjpDroppedFirst() -> (Array, (TangentVector) -> TangentVector) {
        // Capture only the flag, not the whole array, in the pullback closure.
        let wasEmpty = isEmpty
        return (droppedFirst(), { v in
            wasEmpty ? v : TangentVector([Element.TangentVector.zero] + v.base)
        })
    }

    /// Returns a copy of the array with `element` appended.
    @differentiable(vjp: _vjpAppending)
    func appending(_ element: Element) -> Array {
        var tmp = self
        tmp.append(element)
        return tmp
    }

    /// VJP for `appending(_:)`. The pullback splits the output tangent: all but the
    /// last entry flow back to `self`, and the last entry flows back to `element`.
    /// The incoming tangent must be non-empty (it reads its last entry).
    func _vjpAppending(_ element: Element) -> ([Element], (TangentVector) -> (TangentVector, Element.TangentVector)) {
        func pb(_ v: TangentVector) -> (TangentVector, Element.TangentVector) {
            return (TangentVector([Element.TangentVector](v.base.dropLast())), v.base[v.base.count - 1])
        }
        return (appending(element), pb)
    }

    /// Returns a one-element array containing `element`.
    @differentiable(vjp: _vjpMakeSingle)
    static func makeSingle(_ element: Element) -> Array {
        return [element]
    }

    /// VJP for `makeSingle(_:)`. The pullback requires a one-entry tangent and
    /// returns that entry as the gradient of `element`.
    static func _vjpMakeSingle(_ element: Element) -> (Array, (TangentVector) -> Element.TangentVector) {
        return ([element], { v in
            precondition(v.base.count == 1)
            return v.base[0]
        })
    }
}
// MARK: - Custom VJP for stdlib sort.

/// Returns `array` sorted in ascending order. It is differentiable through the
/// hand-written VJP `_vjpSorted(_:)`.
@differentiable(vjp: _vjpSorted)
func sorted(_ array: [Double]) -> [Double] {
    return array.sorted()
}

/// VJP for `sorted(_:)`.
///
/// Sorting is a permutation: output position `i` holds input element
/// `permutation[i]`. The derivative of output `i` with respect to that input is 1
/// and 0 elsewhere, so the pullback scatters `v[i]` back to index `permutation[i]`.
/// Tied values are routed according to whatever order `sorted(by:)` produced.
///
/// - Parameter array: The values to sort.
/// - Returns: The sorted values and a pullback. The pullback always returns a tangent
///   with `array.count` entries. An empty incoming tangent yields all zeros.
///   NOTE(review): this assumes a zero `Array.DifferentiableView` may be represented
///   by an empty array; confirm for the toolchain in use. Any other tangent whose
///   count differs from `array.count` is a precondition failure, instead of the
///   index-out-of-range trap the original sizing produced.
func _vjpSorted(_ array: [Double]) -> ([Double], (Array<Double>.DifferentiableView) -> Array<Double>.DifferentiableView) {
    let sort = array.enumerated().sorted(by: { $0.element < $1.element })
    let sortedValues = sort.map { $0.element }
    let permutation = sort.map { $0.offset }
    return (sortedValues, { v in
        precondition(v.base.isEmpty || v.base.count == permutation.count,
                     "tangent count must match the number of sorted elements")
        // Size by the input count, not by the incoming tangent's count.
        var result = Array(repeating: 0.0, count: permutation.count)
        for (sourceIndex, gradient) in zip(permutation, v.base) {
            result[sourceIndex] = gradient
        }
        return Array<Double>.DifferentiableView(result)
    })
}
let arrayToSort: [Double] = [7, 2, 4, 1, 8, 3, 0, 9]

// One-hot vectors e_0 ... e_{n-1}. Pulling e_i back through a sort's pullback gives
// the gradient of output i with respect to every input, i.e. row i of the Jacobian.
var vectorsToPullBack: [[Double]] = arrayToSort.indices.map { hot in
    var basis = [Double](repeating: 0.0, count: arrayToSort.count)
    basis[hot] = 1
    return basis
}

let (value, pb) = valueWithPullback(at: arrayToSort, in: sorted)
print("USING CUSTOM DERIVATIVE FOR SORT")
print(value)
vectorsToPullBack.forEach { basis in
    print(pb(Array.DifferentiableView(basis)))
}
print("")
// MARK: - Selection sort.

/// Returns the index of the largest element of `array`. On ties the earliest such
/// index wins, because a later element replaces the current best only if it is
/// strictly greater. Traps on an empty array, since it reads `array[0]`.
func argMax(_ array: [Double]) -> Int {
    var bestIndex = 0
    var bestValue = array[bestIndex]
    for candidate in array.indices where array[candidate] > bestValue {
        bestIndex = candidate
        bestValue = array[candidate]
    }
    return bestIndex
}
/// Sorts `array` ascending by recursive selection sort.
///
/// The largest element is swapped to the front. The remaining elements are sorted
/// recursively, and the largest element is appended at the end. Every array
/// operation goes through the custom-VJP helpers (`swappedAt`, `droppedFirst`,
/// `appending`). The index choice uses `array.withoutDerivative()`, presumably
/// because `argMax` itself is not differentiable.
func selectionSort(_ array: [Double]) -> [Double] {
    guard array.count > 1 else { return array }
    let largestIndex = argMax(array.withoutDerivative())
    let largestFirst = array.swappedAt(0, largestIndex)
    let remainder = largestFirst.droppedFirst()
    return selectionSort(remainder).appending(largestFirst[0])
}
// Compare AD through selection sort against the custom sort VJP (`value`, `pb`).
let (value2, pb2) = valueWithPullback(at: arrayToSort, in: selectionSort)
print("USING AUTOMATICALLY COMPUTED DERIVATIVE OF SELECTION SORT")
print(value2)
if value2 != value {
    print(" oh no, that one is wrong")
}
for v in vectorsToPullBack {
    let tangent = Array.DifferentiableView(v)
    // Evaluate the pullback once per vector and reuse it for printing and comparison;
    // previously pb2 ran twice for every vector.
    let actual = pb2(tangent)
    print(actual)
    if actual != pb(tangent) {
        print(" oh no, that one is wrong")
    }
}
print("")
// MARK: - Quicksort.

extension Array where Element: Differentiable {
    /// Recursively collects the elements at indices `start..<count` that satisfy
    /// `predicate`. It is built from the `appending` helper rather than stdlib
    /// `filter`.
    ///
    /// - Parameters:
    ///   - predicate: Test applied to each candidate element.
    ///   - start: First index to examine. Must be `<= count`; a larger value traps
    ///     with index-out-of-range.
    /// - Returns: The matching elements in *reverse* index order, because each
    ///   recursion level appends its own element after the already-filtered tail.
    ///   Each level copies that tail inside `appending`, so the work is quadratic in
    ///   the number of kept elements, and the recursion depth is `count - start`.
    func filter(_ predicate: (Element) -> Bool, _ start: Int) -> Array {
        guard start != count else { return [] }
        let candidate = self[start]
        let keep = predicate(candidate)
        let tail = filter(predicate, start + 1)
        if keep {
            return tail.appending(candidate)
        }
        return tail
    }
}
/// Sorts `array` ascending with a recursive quicksort that uses the first element as
/// the pivot.
///
/// Elements equal to the pivot go to the right-hand partition. The partition
/// predicates compare against `pivot.withoutDerivative()`; presumably this keeps the
/// comparison closures from capturing a differentiated value. Values reach the
/// result through `filter`, `makeSingle`, and `+`.
func qsort(_ array: [Double]) -> [Double] {
    guard array.count > 1 else { return array }
    let pivot = array[0]
    let threshold = pivot.withoutDerivative()
    let below = array.filter({ $0 < threshold }, 1)
    let atOrAbove = array.filter({ $0 >= threshold }, 1)
    return qsort(below) + Array.makeSingle(pivot) + qsort(atOrAbove)
}
// Compare AD through quicksort against the custom sort VJP (`value`, `pb`).
let (value3, pb3) = valueWithPullback(at: arrayToSort, in: qsort)
print("USING AUTOMATICALLY COMPUTED DERIVATIVE OF QUICK SORT")
print(value3)
if value3 != value {
    print(" oh no, that one is wrong")
}
for v in vectorsToPullBack {
    let tangent = Array.DifferentiableView(v)
    // Evaluate the pullback once per vector and reuse it for printing and comparison;
    // previously pb3 ran twice for every vector.
    let actual = pb3(tangent)
    print(actual)
    if actual != pb(tangent) {
        print(" oh no, that one is wrong")
    }
}
// (End of gist. GitHub page footer: sign up or sign in on GitHub to comment on this gist.)