Created
August 25, 2019 07:03
-
-
Save sharanry/27d1616bc3c4a4f0a3bfd70c286dcc0b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/interface.jl b/src/interface.jl | |
index 4536bda..1581def 100644 | |
--- a/src/interface.jl | |
+++ b/src/interface.jl | |
@@ -98,7 +98,7 @@ julia> cb = b ∘ b; | |
julia> x = randn(2, 3) | |
2×3 Array{Float64,2}: | |
- 0.0660476 -0.77195 -1.7832 | |
+ 0.0660476 -0.77195 -1.7832 | |
-0.147743 -1.46459 0.264924 | |
julia> forward(cb, x) | |
@@ -107,10 +107,10 @@ ERROR: MethodError: no method matching +(::Array{Float64,1}, ::Float64) | |
julia> forward(cb, x, zeros(size(x, 2))) | |
(rv = [1.10887 0.32029 -0.704563; -0.639206 -1.97935 -0.243419], logabsdetjac = [0.018534, 1.46352e-5, 0.00521633]) | |
``` | |
- | |
+ | |
""" | |
forward(b::Bijector, x) = forward(b, x, zero(eltype(x))) | |
-forward(b::Bijector, x, logjac) = (rv=b(x), logabsdetjac=logjac + logabsdetjac(b, x)) | |
+forward(b::Bijector, x, logjac) = (rv=b(x), logabsdetjac=logjac .+ logabsdetjac(b, x))  # `.+` so logjac may be a scalar or a per-sample vector (one entry per column of x), avoiding the `+(::Array, ::Float64)` MethodError | |
forward(ib::Inversed{<: Bijector}, y) = ( | |
rv=ib(y), | |
logabsdetjac=logabsdetjac(ib, y) | |
@@ -227,11 +227,11 @@ logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) | |
function _forward(f, b1::Bijector, b2::Bijector) | |
f1 = forward(b1, f.rv) | |
f2 = forward(b2, f1.rv) | |
- return (rv=f2.rv, logabsdetjac=f2.logabsdetjac + f1.logabsdetjac + f.logabsdetjac) | |
+ return (rv=f2.rv, logabsdetjac=f2.logabsdetjac .+ f1.logabsdetjac .+ f.logabsdetjac)  # broadcast so scalar and per-sample vector log-Jacobians can be summed | |
end | |
function _forward(f, b::Bijector, bs::Bijector...) | |
f1 = forward(b, f.rv) | |
- f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac + f.logabsdetjac) | |
+ f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac .+ f.logabsdetjac)  # broadcast so scalar and per-sample vector log-Jacobians can be summed | |
return _forward(f_, bs...) | |
end | |
# if `x` represents multiple elements to act on, we want to allow the user to | |
@@ -244,7 +244,7 @@ end | |
function forward(cb::Composed, x, logjac) | |
rv = x | |
logjac_ = logjac | |
- | |
+ | |
for t in cb.ts | |
res = forward(t, rv) | |
rv = res.rv | |
@@ -308,17 +308,23 @@ struct Shift{T} <: Bijector | |
a::T | |
end | |
-(b::Shift)(x) = b.a + x | |
+Shift(dims::Int, container=Array) = Shift(container(zeros(dims, 1))) | |
+ | |
+(b::Shift)(x) = b.a .+ x  # untyped so scalar shifts keep working (matches the T<:Real logabsdetjac below); `.+` also broadcasts over arrays | |
inv(b::Shift) = Shift(-b.a) | |
-logabsdetjac(b::Shift, x::T) where T = zero(T) | |
+logabsdetjac(b::Shift, x::T) where T<:Real = zero(T) | |
+logabsdetjac(b::Shift, x::T) where T<:AbstractArray = zeros(eltype(x), size(x, 2))  # one zero per sample (column of x) | |
struct Scale{T} <: Bijector | |
a::T | |
end | |
+Scale(dims::Int, container=Array) = Scale(container(one(zeros(dims, dims))))  # identity matrix; built from zeros so constructing a Scale does not advance the global RNG as randn would | |
+ | |
(b::Scale)(x) = b.a * x | |
-inv(b::Scale) = Scale(b.a^(-1)) | |
-logabsdetjac(b::Scale, x) = log(abs(b.a)) | |
+inv(b::Scale) = Scale(inv(b.a)) | |
+logabsdetjac(b::Scale, x::T) where T<: Real = log(abs(b.a)) | |
+logabsdetjac(b::Scale, x::T) where T<: AbstractArray = fill(log(abs(det(b.a))) * (b.a isa Number ? size(x, 1) : 1), size(x, 2))  # one value per column; a scalar a scales all size(x, 1) coordinates, so log|det J| = size(x, 1) * log|a| | |
#################### | |
# Simplex bijector # | |
@@ -404,7 +410,7 @@ function (ib::Inversed{<: SimplexBijector{Val{proj}}})(y::AbstractVector{T}) whe | |
else | |
x[K] = _clamp(one(T) - sum_tmp - y[K], ib.orig) | |
end | |
- | |
+ | |
return x | |
end | |
@@ -438,7 +444,7 @@ end | |
function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where T | |
ϵ = _eps(T) | |
lp = zero(T) | |
- | |
+ | |
K = length(x) | |
sum_tmp = zero(eltype(x)) | |
@@ -460,7 +466,7 @@ end | |
DistributionBijector(d::Distribution) | |
DistributionBijector{<: ADBackend, D}(d::Distribution) | |
-This is the default `Bijector` for a distribution. | |
+This is the default `Bijector` for a distribution. | |
It uses `link` and `invlink` to compute the transformations, and `AD` to compute | |
the `jacobian` and `logabsdetjac`. | |
@@ -503,7 +509,7 @@ const Transformed = TransformedDistribution | |
Couples distribution `d` with the bijector `b` by returning a `TransformedDistribution`. | |
-If no bijector is provided, i.e. `transformed(d)` is called, then | |
+If no bijector is provided, i.e. `transformed(d)` is called, then | |
`transformed(d, bijector(d))` is returned. | |
""" | |
transformed(d::Distribution, b::Bijector) = TransformedDistribution(d, b) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment