Skip to content

Instantly share code, notes, and snippets.

@marcrasi
Created June 27, 2019 23:18
Show Gist options
  • Save marcrasi/ec5a8fd5a4eef18f2cd353c9993d2b94 to your computer and use it in GitHub Desktop.
import TensorFlow
/// A differentiable layer abstraction: a conforming type maps a differentiable
/// `Input` to a `Tensor<Float>` via a differentiable `forward` requirement.
protocol MyLayer: Differentiable {
  /// The layer's input type; must itself be differentiable so that
  /// pullbacks can produce an `Input.TangentVector`.
  associatedtype Input: Differentiable

  /// The layer's forward computation.
  /// Marked `@differentiable` so conformances must supply a differentiable body.
  @differentiable
  func forward(_ x: Input) -> Tensor<Float>
}
/// Forwarding machinery that routes the `forward` requirement through a
/// manually registered VJP (legacy Swift-for-TensorFlow `@differentiable(vjp:)`
/// syntax). This wrapper pattern is the standard workaround for attaching a
/// custom derivative to a protocol requirement from a protocol extension.
extension MyLayer {
  /// Calls the conforming type's `forward`; differentiable via `callGrad`.
  @differentiable(vjp: callGrad)
  func callForward(_ x: Input) -> Tensor<Float> {
    return forward(x)
  }

  /// VJP for `callForward`: returns the primal value and a pullback.
  ///
  /// NOTE(review): the pullback produces zero tangents for both `self` and
  /// the input — presumably a deliberate placeholder for this minimal
  /// example; a real implementation would derive it from `forward`.
  func callGrad(_ x: Input) -> (Tensor<Float>, (Tensor<Float>) -> (Self.TangentVector, Input.TangentVector)) {
    func pb(x: Tensor<Float>) -> (TangentVector, Input.TangentVector) {
      return (TangentVector.zero, Input.TangentVector.zero)
    }
    // Fix: the primal result of a VJP must be the actual forward value.
    // The original returned `Tensor(0)` here, so evaluating `callForward`
    // through differentiation yielded a constant zero instead of `forward(x)`.
    return (forward(x), pb)
  }
}
/// A minimal affine layer with scalar parameters: `a` scales the input and
/// `b` offsets it. Defaults (`a = 1`, `b = 0`) make the layer the identity.
struct Dense: MyLayer {
  var a: Tensor<Float> = Tensor(1.0)
  var b: Tensor<Float> = Tensor(0.0)

  /// Applies the affine transform `a * x + b`.
  ///
  /// Fix: the original body returned the constant `Tensor(0)`, ignoring both
  /// the input and the stored parameters, so gradients with respect to `a`
  /// and `b` were identically zero and the layer was untrainable.
  @differentiable
  func forward(_ x: Tensor<Float>) -> Tensor<Float> {
    return a * x + b
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment