Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Last active September 12, 2020 10:28
Show Gist options
  • Save torfjelde/8675bba686afdf693476ae1c70f516d3 to your computer and use it in GitHub Desktop.
Save torfjelde/8675bba686afdf693476ae1c70f516d3 to your computer and use it in GitHub Desktop.
"""
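    @bijector function f(b::Bijector, x) ... end

Take a `forward`-style method body and use it to define `transform`, `logabsdetjac`
and `forward` for `b`, as well as the call overload `(b::Bijector)(x)`.

The body is expected to assign `rv = ...` and `logabsdetjac = ...` at its top level.
Statements preceding the first of these assignments are copied into all three methods,
so that `forward` performs the shared computation only once.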
"""
macro bijector(expr)
    # Expands a single `forward`-style definition into four methods:
    #   * `Bijectors.transform(args...)`    -> shared statements, then the RHS of `rv = ...`
    #   * `(b)(x)`                          -> forwards to `Bijectors.transform`
    #   * `Bijectors.logabsdetjac(args...)` -> shared statements, then the RHS of
    #                                          `logabsdetjac = ...`
    #   * `Bijectors.forward(args...)`      -> the whole body, returning
    #                                          `(rv = rv, logabsdetjac = logabsdetjac)`
    # The function name used in `expr` is ignored. `where` parameters of `expr` are
    # attached to every generated method.
    #
    # Throws an `ArgumentError` during expansion if `expr` has fewer than two positional
    # arguments, if one of the first two arguments is unnamed, or if the body lacks a
    # top-level `rv = ...` or `logabsdetjac = ...` assignment.
    def = MacroTools.splitdef(expr)
    body = def[:body]
    args = def[:args]
    whereparams = def[:whereparams]

    length(args) >= 2 || throw(ArgumentError(
        "@bijector expects a definition of the form `f(b::Bijector, x)`, " *
        "got $(length(args)) positional argument(s)"
    ))
    # The first two positional arguments are the bijector and its input.
    bijector_arg, input_arg = args[1], args[2]

    # Attach the `where` parameters of the original definition (if any) to `call`.
    withwhere(call) = isempty(whereparams) ? call : Expr(:where, call, whereparams...)

    # `Bijectors.<name>`, referring to the module object itself so that the expansion
    # does not depend on what `Bijectors` means in the caller's namespace.
    qualified(name::Symbol) = Expr(:., Bijectors, QuoteNode(name))

    # Build `function Bijectors.<name>(args...) where {...}; stmts...; end`.
    function method_expr(name::Symbol, stmts)
        call = Expr(:call, qualified(name), args...)
        return Expr(:function, withwhere(call), Expr(:block, stmts...))
    end

    # Turn an argument *declaration* (`x`, `x::T`, `x = 1`, `x...`) into the expression
    # that passes this argument on in a call (`x`, or `x...` for a slurped argument).
    function forwarded(arg)
        name, _, isslurp, _ = MacroTools.splitarg(arg)
        name === nothing && throw(ArgumentError(
            "@bijector requires named arguments, got `$(arg)`"
        ))
        return isslurp ? Expr(:..., name) : name
    end

    # Figure out what is shared, and what isn't.
    shared_exprs = Any[]     # beginning of all three methods
    transform_exprs = Any[]  # `return <rhs of rv>`
    logjac_exprs = Any[]     # `return <rhs of logabsdetjac>`
    tail_exprs = Any[]       # everything from the first `rv`/`logabsdetjac` onwards;
                             # only emitted into `forward`
    for stmt in body.args
        if Meta.isexpr(stmt, :(=)) && stmt.args[1] === :rv
            push!(transform_exprs, Expr(:return, stmt.args[2]))
            push!(tail_exprs, stmt)
        elseif Meta.isexpr(stmt, :(=)) && stmt.args[1] === :logabsdetjac
            push!(logjac_exprs, Expr(:return, stmt.args[2]))
            push!(tail_exprs, stmt)
        elseif isempty(transform_exprs) && isempty(logjac_exprs)
            # Nothing has been assigned yet, so this statement is shared.
            push!(shared_exprs, stmt)
        else
            push!(tail_exprs, stmt)
        end
    end

    # Without these assignments the generated methods would have nothing to return.
    isempty(transform_exprs) && throw(ArgumentError(
        "@bijector: the body must contain a top-level assignment `rv = ...`"
    ))
    isempty(logjac_exprs) && throw(ArgumentError(
        "@bijector: the body must contain a top-level assignment `logabsdetjac = ...`"
    ))

    # `forward` evaluates the full body once and returns both results.
    forward_return = Expr(
        :return,
        Expr(:tuple, Expr(:(=), :rv, :rv), Expr(:(=), :logabsdetjac, :logabsdetjac)),
    )
    forward_exprs = vcat(shared_exprs, tail_exprs, Any[forward_return])

    # Define the `(b::Bijector)(x::T)` call overload. Only the argument *names* are
    # forwarded; type annotations and default values stay in the signature.
    bijector_call_expr = Expr(
        :function,
        withwhere(Expr(:call, bijector_arg, input_arg)),
        Expr(
            :block,
            Expr(
                :call, qualified(:transform), forwarded(bijector_arg), forwarded(input_arg)
            ),
        ),
    )

    # HACK: `esc` because the types of the arguments would otherwise be resolved in the
    # namespace of `Bijectors`, since that is where the macro is expanded. Not sure
    # whether escaping everything is the best solution.
    return esc(Expr(
        :block,
        method_expr(:transform, vcat(shared_exprs, transform_exprs)),
        bijector_call_expr,
        method_expr(:logabsdetjac, vcat(shared_exprs, logjac_exprs)),
        method_expr(:forward, forward_exprs),
    ))
end
@torfjelde
Copy link
Author

torfjelde commented Sep 12, 2020

Example:

julia> using Bijectors, StatsFuns

julia> import Bijectors: @bijector

julia> struct Softplus <: Bijector{0} end

julia> @bijector function f(b::Softplus, x)
           rv = (x isa Real) ? StatsFuns.softplus(x) : StatsFuns.softplus.(x)
           logabsdetjac = (x isa Real) ? -StatsFuns.log1pexp(-x) : -StatsFuns.log1pexp.(-x)
       end

julia> @bijector function f(b::Inverse{<:Softplus}, x)
           rv = (x isa Real) ? StatsFuns.invsoftplus(x) : StatsFuns.invsoftplus.(x)
           logabsdetjac = (x isa Real) ? StatsFuns.log1pexp(-x) : StatsFuns.log1pexp.(-x)
       end

julia> b = Softplus()
Softplus()

julia> ib = inv(b)
Inverse{Softplus,0}(Softplus())

julia> ib(b(-1))
-1.0

julia> Bijectors.forward(b, -1)
(rv = 0.31326168751822286, logabsdetjac = -1.3132616875182228)

julia> Bijectors.forward(ib, [1., 2.])
(rv = [0.541324854612918, 1.854586542131141], logabsdetjac = [0.31326168751822286, 0.1269280110429725])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment