Last active
May 17, 2024 08:32
-
-
Save torfjelde/e455f1a5c44c65496ecd651ed49881a5 to your computer and use it in GitHub Desktop.
Example of how to programmatically _generate_ a Turing.jl model. This can be useful in performance-critical code, where one wants to unroll loops of `~` statements (and similar constructs) to improve performance.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
julia> using DynamicPPL, Distributions | |
"""
    NTModel(nt::NamedTuple)

Template wrapper around a `NamedTuple`. `define_model_from_template` reads its
keys as variable names and its values as the distributions those variables are
drawn from. Keeping the `NamedTuple` as a type parameter puts the variable names
in the type domain.
"""
struct NTModel{Names,Types}
    # Variable names (keys) paired with their distributions (values).
    nt::NamedTuple{Names,Types}
end
julia> model_template = NTModel((a=Normal(0,1), b=Normal(100, 1))) | |
NTModel{(:a, :b), Tuple{Normal{Float64}, Normal{Float64}}}((a = Normal{Float64}(μ=0.0, σ=1.0), b = Normal{Float64}(μ=100.0, σ=1.0))) | |
"""
    define_model_from_template(model_template::NTModel; model_name::Symbol=gensym(:nt_model))

Generate a DynamicPPL `@model` function named `model_name` from `model_template`,
and return that function.

The generated body contains one unrolled `name ~ dist` statement per entry of
`model_template.nt`, in the template's order. The model returns the drawn values
as a `NamedTuple` with the template's keys.

The definition is evaluated with `@eval`, so it becomes a global in the module
where `define_model_from_template` itself is defined. The default `model_name`
is a fresh `gensym`, so callers should keep the returned function instead of
relying on its name.
"""
function define_model_from_template(
    model_template::NTModel; model_name::Symbol=gensym(:nt_model)
)
    nt = model_template.nt
    # Named `varnames` (not `names`) to avoid shadowing `Base.names`.
    varnames = keys(nt)
    # One `~` statement per template entry. The distribution object itself is
    # spliced into the expression. Typed as `Expr[]` rather than `Vector{Any}`.
    body = Expr[:($name ~ $dist) for (name, dist) in pairs(nt)]
    # `(a, b, ...)`: the variables bound by the `~` statements above.
    retvals = Expr(:tuple, varnames...)
    return @eval @model function $(model_name)()
        $(body...)
        return NamedTuple{$varnames}($retvals)
    end
end
define_model_from_template (generic function with 2 methods) | |
julia> # Define the model from the template. | |
# NOTE: we need to capture the returned function, since every call
# generates a new (gensym'd) function name.
demo_model = define_model_from_template(model_template) | |
##nt_model#369 (generic function with 2 methods) | |
julia> demo_model()() | |
(a = 1.4612352529081432, b = 100.0517306153272) | |
julia> # We can then use this inside another model using `@submodel`. | |
@model function outer_model(inner_model)
    # Run `inner_model` inside this model and bind its return value. For the
    # generated models above this is a NamedTuple such as `(a = ..., b = ...)`.
    @submodel parameters = inner_model
    # Use the inner model's draws as the mean and standard deviation of `x`.
    x ~ Normal(parameters.a, parameters.b)
    return (x = x, parameters = parameters)
end
outer_model (generic function with 2 methods) | |
julia> model = outer_model(demo_model()) | |
Model{typeof(outer_model), (:inner_model,), (), (), Tuple{Model{var"###nt_model#369", (), (), (), Tuple{}, Tuple{}, DefaultContext}}, Tuple{}, DefaultContext}(outer_model, (inner_model = Model{var"###nt_model#369", (), (), (), Tuple{}, Tuple{}, DefaultContext}(var"##nt_model#369", NamedTuple(), NamedTuple(), DefaultContext()),), NamedTuple(), DefaultContext()) | |
julia> model() | |
(x = -7.9766227799777365, parameters = (a = 0.5489158156933944, b = 99.56385067809187)) | |
julia> # Alternative approach: specify the model with a macro. This cannot be used for programmatic generation, since the macro only sees the expressions written literally at the call site.
"""
    @model_from_namedtuple name1=dist1 name2=dist2 ...

Construct a model from specifications of the form `lhs = rhs`.

Each `lhs` must be a plain variable name, and `rhs` is the expression for its
distribution. Every pair becomes a `lhs ~ rhs` statement, in the order given.
The model returns the drawn values as a `NamedTuple` keyed by the `lhs` names.
The generated model function gets a fresh `gensym`'d name, so capture the value
the macro returns.

Throws an `ArgumentError` at macro-expansion time in two cases: an argument is
not of the form `name = expr` with a plain variable name on the left, or a
variable name is repeated.

# Examples
```julia
julia> nt_demo = @model_from_namedtuple a=Normal(0, 1) b=Normal(100, 1)
##nt_model#590 (generic function with 2 methods)

julia> model = nt_demo();

julia> model()
(a = -0.8077731696095273, b = 99.11691965855493)
```

Later right-hand sides may depend on earlier variables, as long as the
specifications are ordered correctly:

```julia
julia> model_demo_alt = @model_from_namedtuple a=Normal(0, 1) b=Normal(10 * a, 1)
##nt_model#747 (generic function with 2 methods)

julia> model_demo_alt()()
(a = 1.9066394738007792, b = 18.95744677736865)
```
"""
macro model_from_namedtuple(exprs...)
    # Validate and split every `lhs = rhs` specification. Explicit exceptions are
    # used (not `@assert`) because this checks user input.
    lhs_rhs_pairs = map(exprs) do ex
        Meta.isexpr(ex, :(=), 2) || throw(ArgumentError(
            "All expressions should be of the form `lhs = rhs`, got `$(ex)`.",
        ))
        lhs, rhs = ex.args
        # The LHS becomes a NamedTuple key, so it must be a plain Symbol.
        lhs isa Symbol || throw(ArgumentError(
            "The left-hand side of `$(ex)` must be a plain variable name.",
        ))
        (lhs, rhs)
    end
    # Named `varnames` (not `names`) to avoid shadowing `Base.names`.
    varnames = map(first, lhs_rhs_pairs)
    allunique(varnames) || throw(ArgumentError(
        "Variable names must be unique, got $(varnames).",
    ))

    # Body: one `lhs ~ rhs` statement per specification, kept in the given order
    # so later right-hand sides can refer to earlier variables.
    body = Expr[:($lhs ~ $rhs) for (lhs, rhs) in lhs_rhs_pairs]
    # Return value: `NamedTuple{(:a, :b, ...)}((a, b, ...))`.
    retvals = Expr(:tuple, varnames...)
    model_name = gensym(:nt_model)
    expr = :(function $(model_name)()
        $(body...)
        return NamedTuple{$varnames}($retvals)
    end)
    # Hand the function definition to DynamicPPL's internal `@model` implementation.
    # NOTE(review): the trailing `false` is kept from the original; it presumably
    # mirrors `@model`'s default `warn` argument. TODO: confirm against the
    # installed DynamicPPL version.
    return esc(DynamicPPL.model(__module__, __source__, expr, false))
end
@model_from_namedtuple | |
julia> model_demo_alt = @model_from_namedtuple a=Normal(0, 1) b=Normal(100, 1) | |
##nt_model#602 (generic function with 2 methods) | |
julia> model_demo_alt()() | |
(a = -1.2205519233973439, b = 101.82300087682167) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment