// `Differentiable` and `VectorProtocol` (originally named `VectorNumeric`) adapted from:
// https://github.com/apple/swift/blob/tensorflow/stdlib/public/core/AutoDiff.swift

public protocol Differentiable {
  /// The type of values representing this value's derivatives.
  associatedtype TangentVector: Differentiable & AdditiveArithmetic
    where TangentVector.TangentVector == TangentVector

  /// Moves `self` in the given direction; for an optimizer, this applies one update step.
  mutating func move(along direction: TangentVector)
}
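
// For intuition, a minimal hand-written conformance (an illustrative sketch; `TwoParams`
// is not part of the original file): a two-parameter model that is its own tangent
// vector, so `move(along:)` is plain vector addition.
public struct TwoParams: Differentiable, AdditiveArithmetic {
  public var w: Float
  public var b: Float

  public typealias TangentVector = TwoParams

  public static var zero: TwoParams { TwoParams(w: 0, b: 0) }

  public static func + (lhs: TwoParams, rhs: TwoParams) -> TwoParams {
    TwoParams(w: lhs.w + rhs.w, b: lhs.b + rhs.b)
  }

  public static func - (lhs: TwoParams, rhs: TwoParams) -> TwoParams {
    TwoParams(w: lhs.w - rhs.w, b: lhs.b - rhs.b)
  }

  public mutating func move(along direction: TwoParams) {
    w += direction.w
    b += direction.b
  }
}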

public protocol VectorProtocol: AdditiveArithmetic {
  /// The scalar type of the vector space.
  associatedtype VectorSpaceScalar: AdditiveArithmetic

  /// Returns a copy of `self` multiplied elementwise by `scalar`.
  func scaled(by scalar: VectorSpaceScalar) -> Self

  /// Multiplies `self` in place by `scalar`.
  mutating func scale(by scalar: VectorSpaceScalar)
}
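
// Continuing the illustrative `TwoParams` sketch (again, not part of the original
// file): scalar scaling gives it the vector-space structure that `VectorProtocol`
// models. `scale(by:)` comes for free from the extension below.
extension TwoParams: VectorProtocol {
  public typealias VectorSpaceScalar = Float

  public func scaled(by scalar: Float) -> TwoParams {
    TwoParams(w: w * scalar, b: b * scalar)
  }
}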

public extension VectorProtocol {
  mutating func scale(by scalar: VectorSpaceScalar) {
    self = scaled(by: scalar)
  }

  static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self {
    lhs.scaled(by: rhs)
  }

  static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self {
    rhs.scaled(by: lhs)
  }

  static func *= (lhs: inout Self, rhs: VectorSpaceScalar) {
    lhs.scale(by: rhs)
  }
}

public extension VectorProtocol where VectorSpaceScalar: SignedNumeric {
  static prefix func - (x: Self) -> Self {
    .zero - x
  }
}
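
// `Adam.update(_:along:)` below divides one tangent vector by another elementwise and
// adds a scalar to every element; neither operation is expressible with `VectorProtocol`
// alone, so those expressions would not type-check as given. What follows is a minimal
// sketch modeled on `PointwiseMultiplicative` from the AutoDiff.swift file referenced
// above; the `/` and `+` conveniences are local shims added here so that `update`
// compiles, not part of that API.
infix operator .*: MultiplicationPrecedence

public protocol PointwiseMultiplicative: AdditiveArithmetic {
  /// The multiplicative identity (e.g. the all-ones vector).
  static var one: Self { get }

  /// The elementwise multiplicative inverse.
  func reciprocal() -> Self

  /// Elementwise multiplication.
  static func .* (lhs: Self, rhs: Self) -> Self
}

public extension VectorProtocol where Self: PointwiseMultiplicative {
  /// Elementwise division: multiplication by the elementwise reciprocal.
  static func / (lhs: Self, rhs: Self) -> Self {
    lhs .* rhs.reciprocal()
  }

  /// Adds `rhs` to every element of `lhs`.
  static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self {
    lhs + rhs * .one
  }
}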

// `ElementaryFunctions` is assumed to come from the tensorflow-branch stdlib
// (mainline Swift provides an equivalent protocol in swift-numerics).

/// Adam optimizer.
///
/// Reference: "Adam: A Method for Stochastic Optimization".
/// https://arxiv.org/abs/1412.6980v8
@available(macOS 9999, *)  // Placeholder availability while these APIs are experimental.
public class Adam<Model: Differentiable>
where
  // `PointwiseMultiplicative` is required by the elementwise division in `update(_:along:)`.
  Model.TangentVector: VectorProtocol & PointwiseMultiplicative & ElementaryFunctions,
  Model.TangentVector.VectorSpaceScalar: BinaryFloatingPoint & ElementaryFunctions
{
  public typealias Scalar = Model.TangentVector.VectorSpaceScalar

  /// The learning rate.
  public var learningRate: Scalar
  /// The exponential decay rate for the first-moment (mean) estimates.
  public var beta1: Scalar
  /// The exponential decay rate for the second-moment (uncentered variance) estimates.
  public var beta2: Scalar
  /// A small constant added to the denominator for numerical stability.
  public var epsilon: Scalar
  /// The learning-rate decay.
  public var decay: Scalar
  /// The number of update steps taken so far.
  public var step: Int = 0
  /// The running mean of the gradients.
  public var firstMoments: Model.TangentVector = .zero
  /// The running mean of the squared gradients.
  public var secondMoments: Model.TangentVector = .zero

  public init(
    for model: __shared Model,
    learningRate: Scalar = 1e-3,
    beta1: Scalar = 0.9,
    beta2: Scalar = 0.999,
    epsilon: Scalar = 1e-8,
    decay: Scalar = 0
  ) {
    precondition(learningRate >= 0, "Learning rate must be non-negative")
    precondition(0 <= beta1 && beta1 <= 1, "beta1 must be between 0 and 1")
    precondition(0 <= beta2 && beta2 <= 1, "beta2 must be between 0 and 1")
    precondition(decay >= 0, "Learning rate decay must be non-negative")

    self.learningRate = learningRate
    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.decay = decay
  }

  /// Takes one Adam step for `model` opposite `direction` (the gradient at `model`).
  public func update(_ model: inout Model, along direction: Model.TangentVector) {
    step += 1
    let learningRate = self.learningRate / (1 + decay * Scalar(step))
    // Bias-corrected step size: learningRate * sqrt(1 - beta2^step) / (1 - beta1^step).
    // Note: `stepSize` is split into two statements to avoid the "compiler is unable
    // to type-check this expression in reasonable time" error.
    var stepSize = learningRate * Scalar.sqrt(1 - Scalar.pow(beta2, step))
    stepSize = stepSize / (1 - Scalar.pow(beta1, step))
    // Exponential moving averages of the gradient and the squared gradient.
    firstMoments = firstMoments * beta1 + (1 - beta1) * direction
    secondMoments = secondMoments * beta2 + (1 - beta2) * Model.TangentVector.pow(direction, 2)
    // Move along -stepSize * m / (sqrt(v) + epsilon).
    model.move(along: -stepSize * firstMoments / (Model.TangentVector.sqrt(secondMoments) + epsilon))
  }
}
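
// A usage sketch (commented out because the names are hypothetical): `DenseParams`
// stands for some model type conforming to `Differentiable`, `VectorProtocol`,
// `PointwiseMultiplicative`, and `ElementaryFunctions` (for example, `TwoParams`
// above would qualify after adding the latter two conformances), and
// `lossGradient(at:)` for a function producing a gradient in the model's tangent space.
//
//     var model = DenseParams(weight: 0.5, bias: 0.0)
//     let optimizer = Adam(for: model, learningRate: 1e-3)
//     for _ in 0..<1_000 {
//       let gradient = lossGradient(at: model)
//       optimizer.update(&model, along: gradient)
//     }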