Skip to content

Instantly share code, notes, and snippets.

@cossio
Last active May 12, 2020 11:24
Show Gist options
  • Save cossio/34b61ab81c2aef1af099561aa346f037 to your computer and use it in GitHub Desktop.
Save cossio/34b61ab81c2aef1af099561aa346f037 to your computer and use it in GitHub Desktop.
#= reparameterized truncated normal =#
using Random
using Zygote, Distributions, Plots, SpecialFunctions
import Zygote: @adjoint, Numeric
import Base.Broadcast: broadcasted
"""
    ndtr(a::Real)

Return the standard normal cumulative distribution function evaluated at `a`.

The argument is rescaled to `x = a / √2`. For `|x| < 1/√2` the result is computed
from `erf(x)`; otherwise it is computed from `erfc(|x|)` and reflected for
positive `x`. Every branch returns the floating-point type of `x`, so for
example a `Float32` argument yields a `Float32` result.

Follows the branch structure described at
https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.ndtr.html
"""
function ndtr(a::Real)
    invsqrt2 = inv(√oftype(a, 2))
    x = a * invsqrt2
    z = abs(x)
    if z < invsqrt2
        # `(1 + erf(x)) / 2` keeps the type of `x`; the literal `1/2` is a Float64
        # and would promote narrower floats on this branch only.
        return (1 + erf(x)) / 2
    end
    tail = erfc(z) / 2
    return x > 0 ? 1 - tail : tail
end
# Express `x` in standard units of the normal distribution `d`, i.e. `(x - μ) / σ`.
# A non-finite `x` (±Inf, NaN) is passed through unchanged, converted to the type
# the standardized value would have had.
function standardize(x, d::Normal)
    scaled = (x - d.μ) / d.σ
    if isfinite(x)
        return scaled
    else
        return oftype(scaled, x)
    end
end
"""
    u2tn(u, d::Truncated{<:Normal})

Transform a uniform random variable `u` in the interval (0, 1) into a truncated
normal random variable distributed according to `d`.

The result is expressed on the scale of `d` (it lies between `d.lower` and
`d.upper`), so that `tn2u(u2tn(u, d), d) ≈ u`.
"""
function u2tn(u, d::Truncated{<:Normal})
    @assert 0 < u < 1
    base = d.untruncated
    α = standardize(d.lower, base)
    β = standardize(d.upper, base)
    # Standardized sample ζ solving Φ(ζ) = (1 - u) Φ(α) + u Φ(β), written with erf.
    ζ = √2 * erfinv((1 - u) * erf(α / √2) + u * erf(β / √2))
    # Undo the standardization; returning ζ itself would only be correct for μ = 0, σ = 1.
    return base.μ + base.σ * ζ
end
"""
    tn2u(z, d::Truncated{<:Normal})

Map a value `z` drawn from the truncated normal distribution `d` to the
corresponding uniform variable: the standard normal CDF of the standardized `z`,
measured from the lower bound and divided by the CDF mass between the bounds.
"""
function tn2u(z, d::Truncated{<:Normal})
    base = d.untruncated
    Φlo = ndtr(standardize(d.lower, base))
    Φhi = ndtr(standardize(d.upper, base))
    Φz = ndtr(standardize(z, base))
    return (Φz - Φlo) / (Φhi - Φlo)
end
# Either a single value of type `T` or an `N`-dimensional array of such values.
const ValOrArr{T,N} = Union{T,AbstractArray{T,N}}
# Function-style accessors for the fields of a `Truncated` distribution, so they
# can be applied element-wise with dot syntax.
function untruncated(d::Truncated)
    return d.untruncated
end

function lower_bound(d::Truncated)
    return d.lower
end

function upper_bound(d::Truncated)
    return d.upper
end
"""
    tnrand(rng, a, b, μ = 0, σ = 1)

Draw one sample, using `rng`, from a normal distribution with parameters `μ` and
`σ` truncated to the interval between `a` and `b`.
"""
function tnrand(rng::Random.AbstractRNG, a::Real, b::Real, μ::Real = 0, σ::Real = 1)
    dist = truncated(Normal(μ, σ), a, b)
    return rand(rng, dist)
end
# Convenience method: same as the `rng`-taking method, using `Random.GLOBAL_RNG`.
tnrand(a::Real, b::Real, μ::Real = 0, σ::Real = 1) =
    tnrand(Random.GLOBAL_RNG, a, b, μ, σ)
# Reparameterization gradient for sampling a truncated normal.
#
# The sample `z` is viewed as a deterministic function of the distribution
# parameters and of a uniform variable `u`, which is held fixed:
#     z = μ + σ ζ,   Φ(ζ) = (1 - u) Φ(α) + u Φ(β),
#     α = (lower - μ) / σ,   β = (upper - μ) / σ.
# The pullback returns `nothing` for `rng` and a named tuple mirroring the fields
# of `Truncated` for `d`; the cached fields (`lcdf`, `ucdf`, `tp`, `logtp`) get
# no gradient.
@adjoint function Base.rand(rng::Random.AbstractRNG, d::Truncated{<:Normal})
    z = rand(rng, d)
    α = standardize(d.lower, d.untruncated)
    β = standardize(d.upper, d.untruncated)
    ζ = standardize(z, d.untruncated)
    # Recover the uniform variable that would have produced `z`.
    u = tn2u(z, d)
    # ∂ζ/∂α = (1 - u) φ(α) / φ(ζ) and ∂ζ/∂β = u φ(β) / φ(ζ), evaluated in log space.
    dα = exp((ζ^2 - α^2)/2 + log1p(-u))
    dβ = exp((ζ^2 - β^2)/2 + log(u))
    # Chain rule through z = μ + σ ζ: the factor σ cancels the 1/σ from ∂α/∂lower
    # (resp. ∂β/∂upper), so the bound gradients are dα and dβ themselves.
    da = dα
    db = dβ
    # An infinite bound gives 0 * Inf = NaN; its contribution is exactly zero.
    αdα = ifelse(iszero(dα), zero(α * dα), α * dα)
    βdβ = ifelse(iszero(dβ), zero(β * dβ), β * dβ)
    # ∂z/∂μ = 1 - dα - dβ and ∂z/∂σ = ζ - α dα - β dβ. Without the leading `1` and
    # `ζ` terms these would be derivatives of ζ, not of the returned sample z.
    dμ = 1 - (dα + dβ)
    dσ = ζ - (αdα + βdβ)
    back(δ) = (nothing, (untruncated = (μ = δ * dμ, σ = δ * dσ),
                         lower = δ * da, upper = δ * db,
                         lcdf = nothing, ucdf = nothing,
                         tp = nothing, logtp = nothing))
    z, back
end
# Reparameterization gradient for `rand.(d)` where `d` is a truncated normal or
# an array of them.
#
# Element-wise, the sample `z` is viewed as a deterministic function of the
# distribution parameters and of a uniform variable `u`, which is held fixed:
#     z = μ + σ ζ,   Φ(ζ) = (1 - u) Φ(α) + u Φ(β),
#     α = (lower - μ) / σ,   β = (upper - μ) / σ.
# The pullback returns `nothing` for `rand` and a named tuple mirroring the
# fields of `Truncated` for `d`; the cached fields get no gradient.
@adjoint function broadcasted(::typeof(rand), d::ValOrArr{<:Truncated{<:Normal}})
    z = rand.(d)
    # Standardized bounds and sample.
    α = @. standardize(lower_bound(d), untruncated(d))
    β = @. standardize(upper_bound(d), untruncated(d))
    ζ = @. standardize(z, untruncated(d))
    # Recover the uniform variables that would have produced `z`.
    u = @. tn2u(z, d)
    # ∂ζ/∂α = (1 - u) φ(α) / φ(ζ) and ∂ζ/∂β = u φ(β) / φ(ζ), evaluated in log space.
    dα = @. exp((ζ^2 - α^2)/2 + log1p(-u))
    dβ = @. exp((ζ^2 - β^2)/2 + log(u))
    # Chain rule through z = μ + σ ζ: the factor σ cancels the 1/σ from ∂α/∂lower
    # (resp. ∂β/∂upper), so the bound gradients are dα and dβ themselves.
    da = dα
    db = dβ
    # An infinite bound gives 0 * Inf = NaN; its contribution is exactly zero.
    αdα = @. ifelse(iszero(dα), zero(α * dα), α * dα)
    βdβ = @. ifelse(iszero(dβ), zero(β * dβ), β * dβ)
    # ∂z/∂μ = 1 - dα - dβ and ∂z/∂σ = ζ - α dα - β dβ. Without the leading `1` and
    # `ζ` terms these would be derivatives of ζ, not of the returned samples z.
    dμ = @. 1 - (dα + dβ)
    dσ = @. ζ - (αdα + βdβ)
    back(δ) = (nothing, (untruncated = (μ = δ .* dμ, σ = δ .* dσ),
                         lower = δ .* da, upper = δ .* db,
                         lcdf = nothing, ucdf = nothing,
                         tp = nothing, logtp = nothing))
    z, back
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment