Created June 12, 2019 01:27
Save torfjelde/337a959deea5bae4075c99a668a9213f to your computer and use it in GitHub Desktop.
Possible Bijectors.jl interface
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Distributions, Bijectors | |
using ForwardDiff | |
using Tracker | |
using Turing | |
import Random: AbstractRNG | |
import Distributions: logpdf, rand, rand!, _rand!, _logpdf | |
"""
    Bijector

Abstract supertype of every bijector. Concrete subtypes are expected to supply methods of
`transform`, `inverse` and `detloginvjac`.
"""
abstract type Bijector end

"""
    CustomBijector{AD} <: Bijector

Abstract supertype of bijectors whose `detloginvjac` is obtained by automatically
differentiating `inverse`. The parameter `AD` selects the backend; methods exist for
`Turing.Core.ForwardDiffAD` and `Turing.Core.TrackerAD`.
"""
abstract type CustomBijector{AD} <: Bijector end
"""
    transform(b::Bijector, x)

Compute the forward transformation of `x` under the bijector `b`.

Every concrete bijector must add a method for its own type. The generic fallback raises
an error instead of silently returning `nothing`, which would otherwise surface as a
confusing failure further downstream (e.g. inside `rand`).
"""
transform(b::Bijector, x) = error("`transform` is not implemented for $(typeof(b))")

"""
    transform(b::Bijector)

Return the one-argument function `x -> transform(b, x)`.
"""
transform(b::Bijector) = x -> transform(b, x)
"""
    inverse(b::Bijector, y)

Compute the inverse transformation of `y` under the bijector `b`.

Every concrete bijector must add a method for its own type. The generic fallback raises
an error instead of silently returning `nothing`, which would otherwise surface as a
confusing failure further downstream (e.g. inside `logpdf`).
"""
inverse(b::Bijector, y) = error("`inverse` is not implemented for $(typeof(b))")

"""
    inverse(b::Bijector)

Return the one-argument function `y -> inverse(b, y)`.
"""
inverse(b::Bijector) = y -> inverse(b, y)
# TODO: rename? a bit of a mouthful
# TODO: allow batch-computation, especially for univariate case
"""
    detloginvjac(b::Bijector, y)

Compute `log|det J|`, the log of the absolute determinant of the Jacobian of the inverse
transformation `inverse(b, ⋅)` evaluated at `y`. For a scalar `y` this reduces to
`log|d inverse(b, y) / dy|`.

Concrete bijectors must add a method for their own type. `CustomBijector{AD}` subtypes
get one automatically by differentiating `inverse` with the backend `AD` (ForwardDiff or
Tracker). The generic fallback raises an error instead of silently returning `nothing`.
"""
detloginvjac(b::Bijector, y) = error("`detloginvjac` is not implemented for $(typeof(b))")

# ForwardDiff backend: scalar derivative, or full Jacobian, of `inverse(b, ⋅)`.
function detloginvjac(b::CustomBijector{AD}, y::Real) where {AD<:Turing.Core.ForwardDiffAD}
    return log(abs(ForwardDiff.derivative(z -> inverse(b, z), y)))
end

function detloginvjac(
    b::CustomBijector{AD}, y::AbstractVector
) where {AD<:Turing.Core.ForwardDiffAD}
    # `logabsdet` returns `(log|det|, sign)`; only the first entry is needed.
    return logabsdet(ForwardDiff.jacobian(z -> inverse(b, z), y))[1]
end

# Tracker backend. The scalar is wrapped in a length-1 vector and differentiated through
# `z[1]`; the result may still be a tracked value.
# FIXME: untrack? i.e. `Tracker.data(...)`
function detloginvjac(b::CustomBijector{AD}, y::Real) where {AD<:Turing.Core.TrackerAD}
    return log(abs(Tracker.gradient(z -> inverse(b, z[1]), [y])[1][1]))
end

function detloginvjac(
    b::CustomBijector{AD}, y::AbstractVector
) where {AD<:Turing.Core.TrackerAD}
    # `logabsdet` returns `(log|det|, sign)`; only the first entry is needed.
    return logabsdet(Tracker.jacobian(z -> inverse(b, z), y))[1]
end
# Example bijector | |
"""
    Identity <: Bijector

The identity bijector: `transform` and `inverse` both return their argument unchanged.
"""
struct Identity <: Bijector end

transform(::Identity, x) = x
inverse(::Identity, y) = y

# The Jacobian of the identity map is the identity matrix, so log|det J| = log(1) = 0.
# (Returning `one(T)` here would add a spurious +1 to every transformed log-density.)
detloginvjac(::Identity, y::T) where {T<:Real} = zero(T)
detloginvjac(::Identity, y::AbstractVector{T}) where {T<:Real} = zero(T)
# Transformed distributions | |
"""
    UnivariateTransformed{D<:UnivariateDistribution, B<:Bijector}

Continuous univariate distribution of `transform(transform, x)` for `x ~ dist`.

# Fields
- `dist::D`: the base distribution.
- `transform::B`: the bijector applied to samples of `dist`.
"""
# Type-parameter bounds must sit inside `{}`; a trailing `where` after the supertype
# does not constrain the struct's own parameters.
struct UnivariateTransformed{D<:UnivariateDistribution, B<:Bijector} <: Distribution{Univariate, Continuous}
    dist::D
    transform::B
end
"""
    MultivariateTransformed{D<:MultivariateDistribution, B<:Bijector}

Continuous multivariate distribution of `transform(transform, x)` for `x ~ dist`.

# Fields
- `dist::D`: the base distribution.
- `transform::B`: the bijector applied to samples of `dist`.
"""
# Type-parameter bounds must sit inside `{}`; a trailing `where` after the supertype
# does not constrain the struct's own parameters.
struct MultivariateTransformed{D<:MultivariateDistribution, B<:Bijector} <: Distribution{Multivariate, Continuous}
    dist::D
    transform::B
end
# implement these on a case-by-case basis, e.g. `PDMatDistribution = Union{InverseWishart, Wishart}` | |
"""
    transformed(d::Distribution, b::Bijector)

Pair the distribution `d` with the bijector `b`, giving the distribution of
`transform(b, x)` for `x ~ d`. The wrapper type matches the variate form of `d`.
"""
function transformed(d::UnivariateDistribution, b::Bijector)
    return UnivariateTransformed(d, b)
end

function transformed(d::MultivariateDistribution, b::Bijector)
    return MultivariateTransformed(d, b)
end

# Distribution-specific default: `Normal` is already supported on all of ℝ, so it is
# paired with the identity bijector.
transformed(d::Normal) = transformed(d, Identity())
# size | |
# The transformed distribution has the same dimension as its base distribution.
function Base.length(td::MultivariateTransformed)
    return length(td.dist)
end
# logp | |
# Change of variables: log p_Y(x) = log p_X(f⁻¹(x)) + log|d f⁻¹(x) / dx|,
# where f = td.transform and p_X is the density of td.dist.
function logpdf(td::UnivariateTransformed, x::Real)
    base_point = inverse(td.transform, x)
    log_jac = detloginvjac(td.transform, x)
    return logpdf(td.dist, base_point) .+ log_jac
end
# Multivariate change of variables: log-density of the base distribution at f⁻¹(x)
# plus log|det J_{f⁻¹}(x)|, where f = td.transform.
function _logpdf(td::MultivariateTransformed, x::AbstractVector{<:Real})
    base_point = inverse(td.transform, x)
    log_jac = detloginvjac(td.transform, x)
    return logpdf(td.dist, base_point) .+ log_jac
end
"""
    logpdf_with_jac(td::MultivariateTransformed, x)

Return a 2-tuple:
- the transformed log-density at `x` (the same quantity `_logpdf(td, x)` computes);
- the log-Jacobian term `detloginvjac(td.transform, x)` that was added to obtain it.
"""
function logpdf_with_jac(td::MultivariateTransformed, x::AbstractVector{<:Real})
    log_jac = detloginvjac(td.transform, x)
    base_lp = logpdf(td.dist, inverse(td.transform, x))
    return (base_lp .+ log_jac, log_jac)
end
# rand | |
# Draw x ~ td.dist with the caller-supplied `rng`, then push it through the bijector.
# The RNG must be threaded through to the base distribution; otherwise seeding `rng` has no
# effect and samples are not reproducible.
function rand(rng::AbstractRNG, td::UnivariateTransformed)
    return transform(td.transform, rand(rng, td.dist))
end
# In-place multivariate sampling. First fill `x` with a draw from the base distribution,
# then overwrite it with the transformed draw. `copyto!` returns its destination, so `x`
# is returned.
function _rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{<:Real})
    rand!(rng, td.dist, x)
    return copyto!(x, transform(td.transform, x))
end
########################
# Some simple examples #
########################

# Example: `Univariate`. The single-argument `transformed(::Normal)` selects `Identity`.
d = Normal()
td = d |> transformed
@info("Univariate", rand(td, 10))
@info("Univariate", logpdf.(td, rand(td, 10)))

# Example: `Multivariate`. A 5-dimensional normal with zero mean, wrapped explicitly with
# the identity bijector.
d = MvNormal(zeros(5), ones(5))
td = transformed(d, Identity())
@info("Multivariate", rand(td, 10))
@info("Multivariate", logpdf(td, rand(td)))
x = rand(td)
@assert transform(td.transform, x) == x
# Example: Custom stuff | |
"""
    PositiveTransform{AD} <: CustomBijector{AD}

Bijector with forward map `log` and inverse map `exp`. No `detloginvjac` method is
defined here: it comes from the generic `CustomBijector` methods, via the AD backend
`AD`.
"""
struct PositiveTransform{AD} <: CustomBijector{AD} end

function transform(::PositiveTransform, x)
    return log(x)
end

function inverse(::PositiveTransform, y)
    return exp(y)
end
# Choose the AD backend used by `detloginvjac` for this bijector.
t = PositiveTransform{Turing.Core.ADBackend(:reverse_diff)}()
# t = PositiveTransform{Turing.Core.ADBackend(:forward_diff)}() # <= also works
d = InverseGamma()
td = transformed(d, t)
rand(td)
logpdf.(td, rand(td, 10))
inverse(td.transform, rand(td))
detloginvjac(td.transform, rand(td))

x = rand(td.dist)
y = transform(td.transform, x)
# Manual Tracker computation of log|d exp(z)/dz| at z = y. It needs `y`, so it must come
# after `y` is defined.
log(abs(Tracker.gradient(z -> inverse(td.transform, z[1]), [y])[1][1]))
# `exp(log(x))` and the AD log-Jacobian are not bit-exact, so compare approximately.
@assert inverse(td.transform)(transform(td.transform, x)) ≈ x "f ∘ f⁻¹ ≠ identity"
@assert logpdf(td, y) ≈ logpdf(td.dist, x) + log(x) "autodiff detloginvjac ≠ true detloginvjac"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.