Created
May 27, 2019 00:20
-
-
Save torfjelde/23603e16e592dd93d87c61bc609d77af to your computer and use it in GitHub Desktop.
Simple parsing of Turing.Model into MetaGraph, allowing visualization of the probabilistic model.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using MacroTools
using Turing
using LightGraphs, MetaGraphs
# Expressions
#
# Each example is a quoted (unevaluated) short-form function definition whose body
# is a sequence of `~` / `≃` statements. The expressions are only inspected
# syntactically, so free names such as `q_ε` or `h_θ` need not be defined.

# One observed argument `x`; `σ` and `μ` are the latent variables.
ex1 = quote
    m(x) = begin
        # Assumptions
        σ ~ InverseGamma(2, 3)
        μ ~ Normal(0, sqrt(σ))
        # Observations
        x ~ Normal(μ, sqrt(σ))
    end
end

# One observed argument `z` plus a deterministic node `y` derived from it.
ex2 = quote
    m(z) = begin
        # priors
        ε ~ q_ε
        u ~ q_u
        # observation
        z ~ h_θ(u, ε)
        # experimental interface to also track (conditionally) deterministic computations
        y ≃ 1 + z
    end
end

# Two observed arguments `x` and `y` sharing the latent variables `s` and `m`.
ex3 = quote
    gdemo(x, y) = begin
        s ~ InverseGamma(2, 3)
        m ~ Normal(0, sqrt(s))
        x ~ Normal(m, sqrt(s))
        y ~ Normal(m, sqrt(s))
    end
end
# Parsing and construction of graph

# Select which example model is turned into a graph.
ex = ex2

# Split the function definition into a Dict of its parts; `:args` and `:body`
# are the entries used below.
d = MacroTools.splitdef(ex)
@info d
@info d[:args]
@info d[:body]

# Directed graph that receives one vertex per `~` / `≃` statement of the model.
g = MetaDiGraph()

# Maps the symbol on the left-hand side of a statement to its vertex index in `g`.
# Typed explicitly: keys are always `Symbol`s and values are vertex indices.
sym2vertex = Dict{Symbol,Int}()
"""
    add_rv!(g, sym, L, R)

Append a vertex to `g`, tag it as a random variable and return its index.

The vertex receives the properties `:sym => sym`, `:rv => true` and `:expr => R`.
`L` is accepted for signature symmetry with the other helpers but is not used.
"""
function add_rv!(g, sym, L, R)
    add_vertex!(g)
    v = last(vertices(g))  # the highest vertex index of `g`
    # properties
    for (key, val) in (:sym => sym, :rv => true, :expr => R)
        set_prop!(g, v, key, val)
    end
    return v
end
"""
    add_determinstic!(g, sym, L, R)

Append a vertex to `g`, tag it as a deterministic (non-random) node and return its index.

The vertex receives the properties `:sym => sym`, `:rv => false` and `:expr => R`.
`L` is accepted for signature symmetry with the other helpers but is not used.
"""
function add_determinstic!(g, sym, L, R)
    add_vertex!(g)
    node = last(vertices(g))  # the highest vertex index of `g`
    # properties
    set_prop!(g, node, :sym, sym)
    set_prop!(g, node, :rv, false)  # `false` distinguishes this node from a random variable
    set_prop!(g, node, :expr, R)
    return node
end
"""
    add_deps!(g, idx, L, R, lookup=sym2vertex)

Add an edge `lookup[s] → idx` to `g` for every sub-expression `s` of `R` whose
symbolised form is a key of `lookup`, then return `g`.

# Arguments
- `g`: graph that receives the edges.
- `idx`: index of the vertex whose defining expression is `R`.
- `L`: left-hand side of the statement; accepted but not used.
- `R`: right-hand side expression that is searched for known variables.
- `lookup`: mapping from variable symbol to vertex index; defaults to the global
  `sym2vertex`.
"""
function add_deps!(g, idx, L, R, lookup=sym2vertex)
    MacroTools.postwalk(R) do e
        # Only symbols and compound expressions can name a variable. Skipping literals
        # keeps e.g. the string "z" from being mistaken for the variable `z`.
        if e isa Symbol || e isa Expr
            src = get(lookup, Symbol(e), nothing)
            src === nothing || add_edge!(g, src, idx)
        end
        # Hand the node back unchanged: `postwalk` rebuilds each parent from the values
        # returned for its children, so returning anything else would corrupt compound
        # expressions (such as `x[1]`) before they are looked up.
        return e
    end
    return g
end
# Bare names of the model's arguments. `splitarg` strips type annotations and default
# values, so an argument written as `x::Vector` is still recognised as `x`.
observed_syms = Set(first(MacroTools.splitarg(a)) for a in d[:args])

# Walk the model body and add one vertex per `L ~ R` (random variable) or `L ≃ R`
# (deterministic node) statement, together with edges from the variables `R` mentions.
# Every node is returned unchanged, so `expr` equals the original body.
expr = MacroTools.postwalk(d[:body]) do e
    # `@capture` binds `L` and `R` when the pattern matches.
    if @capture(e, L_ ~ R_)
        add_node! = add_rv!
    elseif @capture(e, L_ ≃ R_)
        add_node! = add_determinstic!
    else
        return e
    end

    sym = Symbol(L)
    idx = add_node!(g, sym, L, R)

    # check if observed
    if sym ∈ observed_syms
        @info "observed" sym
        set_prop!(g, idx, :observed, true)
    end

    # Resolve dependencies before registering `sym`: if `R` mentions `sym` itself, the
    # edge then comes from the previous definition (if any) instead of forming a self-loop.
    add_deps!(g, idx, L, R)
    sym2vertex[sym] = idx

    return e
end
# Log the symbol → vertex table, the edge list and the graph itself.
@info sym2vertex
@info collect(edges(g))
@info "Graph" g

# Show the defining expression of `z`. Only `ex2` defines `z`, so guard the lookup to
# keep the script from throwing a `KeyError` when `ex` is `ex1` or `ex3`.
haskey(sym2vertex, :z) ? props(g, sym2vertex[:z])[:expr] : nothing
# visualize
using TikzPictures, TikzGraphs, Cairo, Fontconfig

# special style for different types of nodes
#
# Maps each vertex index to a TikZ style string:
#   * every node is drawn ("draw");
#   * nodes whose `:rv` property is true (or absent) get rounded corners;
#   * nodes flagged `:observed` are filled green, all others blue.
node_styles = Dict{Int,String}()
for v ∈ vertices(g)
    v_props = props(g, v)
    parts = ["draw"]
    get(v_props, :rv, true) && push!(parts, "rounded corners")
    push!(parts, get(v_props, :observed, false) ? "fill=green!10" : "fill=blue!10")
    style = join(parts, ", ")
    node_styles[v] = style
    # Logged through `@info`, like the other diagnostics in this script.
    @info "node style" v style
end
# One label per vertex, in vertex order, taken from the vertex's `:sym` property.
vertex_labels = [String(props(g, v)[:sym]) for v ∈ vertices(g)]

# Plot the underlying plain graph using the per-vertex styles built above.
# (The original sketch also noted a uniform alternative:
#  node_style="draw, fill=blue!10" — presumably styles all nodes alike; confirm
#  against the TikzGraphs documentation.)
p = TikzGraphs.plot(
    g.graph,
    vertex_labels;
    options="scale=2",
    node_styles=node_styles,
)

# Write the picture to disk as a PDF.
TikzPictures.save(PDF("/tmp/test.pdf"), p)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment