# Zygote BatchNorm implementation
using Zygote, Statistics, Flux
# We modify (the implementation of) batchnorm to be more amenable to CPUs pretending to be TPUs.
"""
    ZygoteBatchNorm{F,V,W}

State of a batch-normalization layer whose batch statistics are obtained
through a custom Zygote adjoint (see `batchnorm_statistics`).

The keyword constructor passes seven positional values and the forward pass
reads `BN.ϵ`, so `ϵ` and `momentum` are stored alongside the learnable and
running parameters. Both are `Float32`, matching the constructor defaults
`1f-5` and `0.1f0`.
"""
struct ZygoteBatchNorm{F,V,W}
    λ::F  # activation function
    β::V  # bias
    γ::V  # scale
    μ::W  # moving mean
    σ::W  # moving std
    ϵ::Float32         # added to the batch variance before the square root
    momentum::Float32  # stored by the constructor; not read in the code visible here
end
"""
    ZygoteBatchNorm(chs::Integer, λ = identity; initβ, initγ, ϵ, momentum)

Build a `ZygoteBatchNorm` whose parameter vectors all have length `chs`.

# Arguments
- `chs`: length of the bias, scale, moving-mean and moving-std vectors.
- `λ`: activation function applied to the layer output.

# Keywords
- `initβ`, `initγ`: functions mapping `chs` to the initial bias / scale
  (default: `Float32` zeros / ones).
- `ϵ`: value forwarded to the layer's `ϵ` slot (default `1f-5`).
- `momentum`: value forwarded to the layer's `momentum` slot (default `0.1f0`).

The moving mean starts at zero and the moving std at one.
"""
function ZygoteBatchNorm(chs::Integer, λ = identity;
                         initβ = (i) -> zeros(Float32, i),
                         initγ = (i) -> ones(Float32, i),
                         ϵ = 1f-5,
                         momentum = 0.1f0)
    bias = initβ(chs)
    scale = initγ(chs)
    moving_mean = zeros(Float32, chs)
    moving_std = ones(Float32, chs)
    return ZygoteBatchNorm(λ, bias, scale, moving_mean, moving_std, ϵ, momentum)
end
"""
    compute_affine_shape(x)

Return a tuple with `ndims(x)` entries that is `1` everywhere except at the
second-to-last position, which holds `size(x, ndims(x) - 1)`. Per-channel
vectors reshaped to this tuple broadcast against `x`.
"""
function compute_affine_shape(x)
    rank = ndims(x)
    channel_dim = rank - 1
    return ntuple(d -> d == channel_dim ? size(x, d) : 1, rank)
end
# The result is a tuple of integer sizes, so mark the helper as having no
# gradient and keep Zygote from differentiating through it.
@Zygote.nograd compute_affine_shape
# This is a bit of a trick. We use Zygote's backward pass infrastructure
# to backpropagate the batchnorm statistics to the parameter update
"""
    batchnorm_statistics(bn_μ, bn_σ, bn_ε, x)

Plain (non-differentiated) path: normalize `x` with the stored statistics,
returning `(x .- μ) ./ σ` where `bn_μ` and `bn_σ` are reshaped so that they
broadcast along the second-to-last dimension of `x`.

`bn_ε` is accepted for signature parity with the Zygote adjoint of this
function but is not used here.
"""
function batchnorm_statistics(bn_μ, bn_σ, bn_ε, x)
    bshape = compute_affine_shape(x)
    stored_mean = reshape(bn_μ, bshape)
    stored_std = reshape(bn_σ, bshape)
    return (x .- stored_mean) ./ stored_std
end
# Zygote adjoint for `batchnorm_statistics`.
#
# When Zygote runs the forward pass, this definition replaces the plain method:
# `x` is normalized with statistics computed from the batch itself instead of
# the stored `bn_μ` / `bn_σ`. The pullback returns, in the slots where the
# gradients of `bn_μ` and `bn_σ` would go, the batch mean and the (rescaled)
# batch std, so they are not true gradients. `bn_ϵ` gets `nothing`, and `x`
# gets the batch-norm input gradient.
@Zygote.adjoint function batchnorm_statistics(bn_μ, bn_σ, bn_ϵ, x)
    T = eltype(x)
    ϵ = convert(T, bn_ϵ)

    # Reduce over every dimension except the second-to-last one.
    reduce_dims = tuple(1:(ndims(x)-2)..., ndims(x))
    # Number of elements that feed each per-channel statistic.
    m = prod(size.(Ref(x), reduce_dims))

    # Batch statistics used for the "forward pass"
    μ = mean(x, dims = reduce_dims)
    centered = (x .- μ)
    σ = sqrt.(mean(centered.^2, dims = reduce_dims) .+ ϵ)
    x̂ = centered ./ σ

    # Per-channel vectors (reduced dimensions dropped) captured by the pullback;
    # the std is additionally scaled by m / (m - 1).
    μ_dropped = dropdims(μ, dims = reduce_dims)
    σ_dropped = dropdims(σ, dims = reduce_dims) .* convert(T, m) / convert(T, m - 1)

    function backward_pass(Δ)
        mean_Δ = mean(Δ, dims = reduce_dims)
        mean_Δx̂ = mean(Δ .* x̂, dims = reduce_dims)
        Δ_x = (Δ .- mean_Δ .- x̂ .* mean_Δx̂) ./ σ
        return (μ_dropped, σ_dropped, nothing, Δ_x)
    end

    # Hand back the normalized batch plus the closure producing our "gradients"
    return x̂, backward_pass
end
# Apply the layer to `x`: normalize, scale by `γ`, shift by `β`, then apply the
# activation `λ` elementwise.
function (BN::ZygoteBatchNorm)(x)
    bshape = compute_affine_shape(x)

    # The normalization step is overridden by a Zygote adjoint. Inside a Zygote
    # run (i.e. a backward pass will follow) the forward pass computes batch
    # statistics and hands them to the backward pass. Outside of Zygote no
    # statistics are computed and only the saved `μ` and `σ` are used.
    normalized = batchnorm_statistics(BN.μ, BN.σ, BN.ϵ, x)

    scale = reshape(BN.γ, bshape...)
    shift = reshape(BN.β, bshape...)
    activation = BN.λ
    return activation.(scale .* normalized .+ shift)
end
