Zygote BatchNorm implementation
using Zygote, Statistics, Flux

# We modify (the implementation of) batchnorm to be more amenable to CPUs pretending to be TPUs.
struct ZygoteBatchNorm{F,V,W}
    λ::F                # activation function
    β::V                # bias
    γ::V                # scale
    μ::W                # moving mean
    σ::W                # moving std
    ϵ::Float32          # numerical stabilization constant
    momentum::Float32   # moving-statistics update momentum
end
function ZygoteBatchNorm(chs::Integer, λ = identity;
                         initβ = (i) -> zeros(Float32, i),
                         initγ = (i) -> ones(Float32, i),
                         ϵ = 1f-5,
                         momentum = 0.1f0)
    return ZygoteBatchNorm(λ, initβ(chs), initγ(chs), zeros(Float32, chs), ones(Float32, chs), ϵ, momentum)
end
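# Construction sketch (hypothetical channel count): a layer for a 3-channel
# input. The moving mean starts at zero and the moving std at one.
#
#     bn = ZygoteBatchNorm(3)          # identity activation by default
#     bn = ZygoteBatchNorm(3, relu)    # or with an activation from Flux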
# Compute the broadcast shape that maps per-channel parameters onto `x`:
# all singleton dims except the channel dimension, assumed to be ndims(x) - 1.
function compute_affine_shape(x)
    return ntuple(i -> i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
end
@Zygote.nograd compute_affine_shape
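# A quick sanity check: for a 4-D WHCN array the affine shape broadcasts the
# per-channel parameters across every other dimension.
@assert compute_affine_shape(randn(Float32, 8, 8, 3, 4)) == (1, 1, 3, 1)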
# This is a bit of a trick. We use Zygote's backward pass infrastructure
# to backpropagate the batchnorm statistics to the parameter update.
function batchnorm_statistics(bn_μ, bn_σ, bn_ϵ, x)
    affine_shape = compute_affine_shape(x)
    μ = reshape(bn_μ, affine_shape...)
    σ = reshape(bn_σ, affine_shape...)
    x̂ = (x .- μ) ./ σ
    return x̂
end
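# Outside of a Zygote pass this is plain inference-mode normalization with the
# stored statistics; with μ = 0 and σ = 1 it reduces to the identity:
let x = randn(Float32, 2, 2, 3, 2)
    @assert batchnorm_statistics(zeros(Float32, 3), ones(Float32, 3), 1f-5, x) ≈ x
end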
@Zygote.adjoint function batchnorm_statistics(bn_μ, bn_σ, bn_ϵ, x)
    ϵ = convert(eltype(x), bn_ϵ)
    axes = tuple(1:(ndims(x) - 2)..., ndims(x))
    m = prod(size.(Ref(x), axes))

    # Calculate μ and σ for the "forward pass"
    μ = mean(x, dims = axes)
    meansub = (x .- μ)
    σ = sqrt.(mean(meansub.^2, dims = axes) .+ ϵ)
    x̂ = meansub ./ σ

    # Create dimensionality-dropped versions that will be closed over
    μ_dropped = dropdims(μ, dims = axes)
    σ_dropped = dropdims(σ, dims = axes) .* convert(eltype(x), m) / convert(eltype(x), m - 1)

    backward_pass = Δ -> begin
        # https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
        Δ_x = (Δ .- mean(Δ, dims = axes) .- x̂ .* mean(Δ .* x̂, dims = axes)) ./ σ
        return (μ_dropped, σ_dropped, nothing, Δ_x)
    end

    # Return "forward pass" calculation and closure that returns our "gradients"
    return x̂, backward_pass
end
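# A hedged sketch of the trick in action: under Zygote, the "gradients"
# reported for the first two arguments are really the batch statistics, ready
# to be folded into the moving averages. (Assumes a Zygote version providing
# `Zygote.pullback`; the vintage that shipped `@adjoint` called this
# `Zygote.forward`.)
#
#     x = randn(Float32, 4, 4, 3, 2)
#     x̂, back = Zygote.pullback(batchnorm_statistics,
#                               zeros(Float32, 3), ones(Float32, 3), 1f-5, x)
#     μ_batch, σ_batch, _, Δx = back(ones(Float32, size(x)))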
function (BN::ZygoteBatchNorm)(x)
    affine_shape = compute_affine_shape(x)

    # This chunk of the computation will be overridden through Zygote trickery;
    # we set things up such that when calculating the _forward pass_ during a
    # Zygote run (e.g. we know we will backward pass later) the forward pass
    # keeps track of certain statistics, and provides those to the backward
    # pass when it needs to calculate. If we are not being run within Zygote,
    # we do not calculate those statistics and only use the saved `μ` and `σ`.
    x̂ = batchnorm_statistics(BN.μ, BN.σ, BN.ϵ, x)

    γ = reshape(BN.γ, affine_shape...)
    β = reshape(BN.β, affine_shape...)
    return BN.λ.(γ .* x̂ .+ β)
end
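# End-to-end usage sketch, assuming Zygote's implicit-parameter API
# (`Zygote.Params` / `Zygote.gradient`): the "gradients" that come back for
# `bn.μ` and `bn.σ` are the batch statistics, which a training loop would
# blend into the stored moving averages using `bn.momentum`.
#
#     bn = ZygoteBatchNorm(3)
#     x  = randn(Float32, 8, 8, 3, 4)
#     ps = Zygote.Params([bn.β, bn.γ, bn.μ, bn.σ])
#     gs = Zygote.gradient(() -> sum(bn(x)), ps)
#     # gs[bn.μ]: per-channel batch mean; gs[bn.σ]: Bessel-corrected batch std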