Skip to content

Instantly share code, notes, and snippets.

@crcrpar
Last active July 14, 2019 14:06
Show Gist options
  • Select an option

  • Save crcrpar/d5c71420cec3afddd6ec731be5a16d46 to your computer and use it in GitHub Desktop.

Select an option

Save crcrpar/d5c71420cec3afddd6ec731be5a16d46 to your computer and use it in GitHub Desktop.
// /Library/Developer/Toolchains/swift-tensorflow-DEVELOPMENT-2019-06-17-a.xctoolchain
import TensorFlow
/// A 2-D convolution layer with spectral normalization.
///
/// Before each convolution the filter is divided by an estimate of its largest
/// singular value (its spectral norm), obtained by power iteration on the
/// filter reshaped to a 2-D matrix. In training mode the running singular-vector
/// estimates `u` and `v` are refreshed first; in inference mode the stored
/// estimates are used as-is.
public struct SNConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
// Explicit typealiases avoid these compiler notes on this toolchain:
// TensorFlow.Layer:2:20: note: protocol requires nested type 'Input'; do you want to add it?
// TensorFlow.Layer:3:20: note: protocol requires nested type 'Output'; do you want to add it?
public typealias Input = Tensor<Scalar>
public typealias Output = Tensor<Scalar>
// Convolution parameters — copy of
// https://github.com/tensorflow/swift-apis/blob/master/Sources/TensorFlow/Layers/Convolutional.swift#L128
// `filter` is 4-D; index 3 is treated as the output-channel axis and indices
// 0..<3 are flattened when forming the weight matrix (see `applyingTraining`).
public var filter: Tensor<Scalar>
public var bias: Tensor<Scalar>
@noDerivative public let activation: Activation
@noDerivative public let strides: (Int, Int)
@noDerivative public let padding: Padding
// NOTE(review): `dilations` is stored but never passed to `conv2D` below —
// dilation settings other than (1, 1) currently have no effect. TODO confirm
// whether this toolchain's `conv2D` accepts a dilations argument.
@noDerivative public let dilations: (Int, Int)
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
// Spectral-normalization parameters.
// `u` (shape [outChannels, 1]) and `v` (shape [1, prod(filter.shape[0..<3])])
// are running left/right singular-vector estimates. They are mutated in place
// during the forward pass, so they are held in reference-typed `Parameter`
// boxes and excluded from differentiation via `@noDerivative`.
@noDerivative public let nPowerIteration: Int
@noDerivative public let eps: Scalar
@noDerivative public var u: Parameter<Scalar>
@noDerivative public var v: Parameter<Scalar>
// Copy of
// https://github.com/tensorflow/swift-apis/blob/master/Sources/TensorFlow/Layers/Convolutional.swift#L128
// extended with spectral-normalization state. The power-iteration count is
// fixed at 1 and `eps` at 1e-12; `u`/`v` start from random normal draws.
public init(
filter: Tensor<Scalar>,
bias: Tensor<Scalar>,
activation: @escaping Activation = identity,
strides: (Int, Int) = (1, 1),
padding: Padding = .valid,
dilations: (Int, Int) = (1, 1)
) {
self.filter = filter
self.bias = bias
self.activation = activation
self.strides = strides
self.padding = padding
self.dilations = dilations
self.nPowerIteration = 1
self.eps = 1e-12
self.u = Parameter(Tensor<Scalar>(randomNormal: [filter.shape[3], 1]))
self.v = Parameter(Tensor<Scalar>(randomNormal: [1, filter.shape[0..<3].contiguousSize]))
}
// L2-normalizes `x`: divides by the square root of the sum of squares,
// with `eps` added to the denominator for numerical stability near zero.
func normalize(_ x: Tensor<Scalar>, _ eps: Scalar) -> Tensor<Scalar> {
return x / (sqrt(x.squared().sum()) + eps)
}
// One or more rounds of power iteration on `weightMatrix`
// (shape [outChannels, flattenedInput]); mutates `u` and `v` in place:
//   v ← normalize(uᵀ W),  u ← normalize(W vᵀ).
// Not differentiated — it only updates the `@noDerivative` estimates.
func updateApproxVectors(_ nPowerIteration: Int, _ weightMatrix: Tensor<Scalar>) {
for _ in 0..<nPowerIteration {
v.value = normalize(u.value.transposed() • weightMatrix, eps)
u.value = normalize(weightMatrix • v.value.transposed(), eps)
}
}
// Approximates the largest singular value of `weightMatrix` as the 1x1
// tensor σ = uᵀ W vᵀ, using the current power-iteration estimates.
@differentiable
func calcMaxSingularValue(_ weightMatrix: Tensor<Scalar>, _ u: Parameter<Scalar>, _ v: Parameter<Scalar>) -> Tensor<Scalar> {
let sigma = u.value.transposed() • weightMatrix • v.value.transposed()
return sigma
}
// Training-mode forward pass: flatten the filter to
// [prod(shape[0..<3]), shape[3]], transpose it, refresh `u`/`v` by power
// iteration, then convolve with the spectrally-normalized filter.
@differentiable
func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
let weightMatrix = filter.reshaped(to: [filter.shape[0..<3].contiguousSize, filter.shape[3]]).transposed()
updateApproxVectors(nPowerIteration, weightMatrix)
return activation(conv2D(
input,
filter: filter / calcMaxSingularValue(weightMatrix, u, v),
strides: (1, strides.0, strides.1, 1),
padding: padding) + bias)
}
// Inference-mode forward pass: identical to training except `u`/`v` are
// NOT updated — the stored estimates are used to normalize the filter.
@differentiable
func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
let weightMatrix = filter.reshaped(to: [filter.shape[0..<3].contiguousSize, filter.shape[3]]).transposed()
return activation(conv2D(
input,
filter: filter / calcMaxSingularValue(weightMatrix, u, v),
strides: (1, strides.0, strides.1, 1),
padding: padding) + bias)
}
// Forward pass with a custom VJP so differentiation dispatches on the
// global learning phase at runtime rather than being traced through
// the `switch` by the automatic-differentiation transform.
@differentiable(vjp: _vjpApplied(to:))
public func callAsFunction(_ input: Input) -> Output {
switch Context.local.learningPhase {
case .training: return applyingTraining(to: input)
case .inference: return applyingInference(to: input)
}
}
// Custom VJP for `callAsFunction`: returns the value plus a pullback from
// the phase-appropriate applying function. Mirrors the pattern used by
// BatchNorm/Dropout in swift-apis of this era.
// NOTE(review): the SIL verifier crash in the log attached to this gist
// ("apply doesn't have right number of arguments", at SN.swift:99) is
// raised while compiling the pullback thunk generated for this function —
// the pullback's result tuple (TangentVector, Tensor) appears to be
// mis-thunked by the 2019-06-17 toolchain. TODO confirm against a newer
// toolchain where `@differentiable(vjp:)` handling was reworked.
@usableFromInline
func _vjpApplied(to input: Tensor<Scalar>) ->
(Tensor<Scalar>, (Tensor<Scalar>) -> (SNConv2D<Scalar>.TangentVector, Tensor<Scalar>)) {
switch Context.local.learningPhase {
case .training:
return valueWithPullback(at: input) {
$0.applyingTraining(to: $1)
}
case .inference:
return valueWithPullback(at: input) {
$0.applyingInference(to: $1)
}
}
}
}
@crcrpar
Copy link
Copy Markdown
Author

crcrpar commented Jul 14, 2019

❯ swift SN.swift
SIL verification failed: apply doesn't have right number of arguments for function: site.getNumArguments() == substConv.getNumSILArguments()
Verifying instruction:
   %0 = argument of bb0 : $*SNConv2D<τ_0_0>.AllDifferentiableVariables // user: %3
   %1 = argument of bb0 : $Tensor<τ_0_0>         // user: %3
   %2 = argument of bb0 : $@callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned SNConv2D<τ_0_0>.AllDifferentiableVariables, @owned Tensor<τ_0_0>) // user: %3
->   %3 = apply %2(%0, %1) : $@callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned SNConv2D<τ_0_0>.AllDifferentiableVariables, @owned Tensor<τ_0_0>) // user: %4
     return %3 : $(SNConv2D<τ_0_0>.AllDifferentiableVariables, Tensor<τ_0_0>) // id: %4
In function:
// AD__$s10TensorFlow0A0VyxG2SN8SNConv2DV26AllDifferentiableVariablesVyx_GADIeggoo_A2dJIeggor_AA0aB13FloatingPointRzlTR_pullback_thunk
sil shared [transparent] [serialized] [thunk] @AD__$s10TensorFlow0A0VyxG2SN8SNConv2DV26AllDifferentiableVariablesVyx_GADIeggoo_A2dJIeggor_AA0aB13FloatingPointRzlTR_pullback_thunk : $@convention(thin) <τ_0_0 where τ_0_0 : TensorFlowFloatingPoint> (@guaranteed Tensor<τ_0_0>, @guaranteed @callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned SNConv2D<τ_0_0>.AllDifferentiableVariables, @owned Tensor<τ_0_0>)) -> (@owned Tensor<τ_0_0>, @out SNConv2D<τ_0_0>.AllDifferentiableVariables) {
// %0                                             // user: %3
// %1                                             // user: %3
// %2                                             // user: %3
bb0(%0 : $*SNConv2D<τ_0_0>.AllDifferentiableVariables, %1 : $Tensor<τ_0_0>, %2 : $@callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned SNConv2D<τ_0_0>.AllDifferentiableVariables, @owned Tensor<τ_0_0>)):
  %3 = apply %2(%0, %1) : $@callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned SNConv2D<τ_0_0>.AllDifferentiableVariables, @owned Tensor<τ_0_0>) // user: %4
  return %3 : $(SNConv2D<τ_0_0>.AllDifferentiableVariables, Tensor<τ_0_0>) // id: %4
} // end sil function 'AD__$s10TensorFlow0A0VyxG2SN8SNConv2DV26AllDifferentiableVariablesVyx_GADIeggoo_A2dJIeggor_AA0aB13FloatingPointRzlTR_pullback_thunk'

Stack dump:
0.      Program arguments: /Library/Developer/Toolchains/swift-tensorflow-DEVELOPMENT-2019-06-17-a.xctoolchain/usr/bin/swift -frontend -interpret SN.swift -enable-objc-interop -sdk /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.14.sdk -color-diagnostics -module-name SN
1.      Swift version 5.1-dev (LLVM e56fafcd29, Swift 8d39256042)
2.      While verifying SIL function "@AD__$s10TensorFlow0A0VyxG2SN8SNConv2DV26AllDifferentiableVariablesVyx_GADIeggoo_A2dJIeggor_AA0aB13FloatingPointRzlTR_pullback_thunk".
 for '_vjpApplied(to:)' (at SN.swift:99:5)
0  swift                    0x0000000103da9fc5 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
1  swift                    0x0000000103da92b5 llvm::sys::RunSignalHandlers() + 85
2  swift                    0x0000000103daa5a8 SignalHandler(int) + 264
3  libsystem_platform.dylib 0x00007fff726a2b5d _sigtramp + 29
4  swift                    0x00000001060a6d00 llvm::dbgs()::thestrm + 0
5  libsystem_c.dylib        0x00007fff7255c6a6 abort + 127
6  swift                    0x0000000100d66118 (anonymous namespace)::SILVerifier::_require(bool, llvm::Twine const&, std::__1::function<void ()> const&) + 616
7  swift                    0x0000000100d82e2c (anonymous namespace)::SILVerifier::checkFullApplySite(swift::FullApplySite) + 636
8  swift                    0x0000000100d6a406 swift::SILInstructionVisitor<(anonymous namespace)::SILVerifier, void>::visit(swift::SILInstruction*) + 5062
9  swift                    0x0000000100d67dcc (anonymous namespace)::SILVerifier::visitSILBasicBlock(swift::SILBasicBlock*) + 1484
10 swift                    0x0000000100d62517 swift::SILFunction::verify(bool) const + 8199
11 swift                    0x0000000100d6544a swift::SILModule::verify() const + 202
12 swift                    0x00000001007ad9c1 swift::Lowering::SILGenModule::~SILGenModule() + 33
13 swift                    0x00000001007b8354 swift::SILModule::constructSIL(swift::ModuleDecl*, swift::SILOptions&, swift::FileUnit*) + 1220
14 swift                    0x00000001007b83c0 swift::performSILGeneration(swift::ModuleDecl*, swift::SILOptions&) + 16
15 swift                    0x00000001004dac0a performCompile(swift::CompilerInstance&, swift::CompilerInvocation&, llvm::ArrayRef<char const*>, int&, swift::FrontendObserver*, swift::UnifiedStatsReporter*) + 10282
16 swift                    0x00000001004d7401 swift::performFrontend(llvm::ArrayRef<char const*>, char const*, void*, swift::FrontendObserver*) + 3025
17 swift                    0x00000001004808d9 main + 729
18 libdyld.dylib            0x00007fff724b73d5 start + 1
19 libdyld.dylib            0x000000000000000a start + 2377419830
zsh: abort      swift SN.swift

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment