Skip to content

Instantly share code, notes, and snippets.

@mschauer
Last active March 7, 2021 23:41
Show Gist options
  • Save mschauer/2b3cdc1f811a31607f8727c0f4cd00a0 to your computer and use it in GitHub Desktop.
Save mschauer/2b3cdc1f811a31607f8727c0f4cd00a0 to your computer and use it in GitHub Desktop.
Get the reference (base) measure of a Soss model
export basemeasure
import MeasureTheory
"""
    basemeasure(c::ConditionalModel, x = NamedTuple())

Return the base (reference) measure of the conditional model `c`.

The work is delegated to the generated `_basemeasure`, which receives the
module type parameter `M` of `c`, the underlying `Model`, the model's argument
values, its observations, and the parameter values `x` (an empty `NamedTuple`
unless given).
"""
function MeasureTheory.basemeasure(c::ConditionalModel{A,B,M}, x=NamedTuple()) where {A,B,M}
    model = Model(c)
    return _basemeasure(M, model, argvals(c), observations(c), x)
end
export sourceBasemeasure

"""
    sourceBasemeasure(m::AbstractModel)

Produce the base-measure source code for `m`: convert it to a `Model` and feed
it to the code generator returned by `sourceBasemeasure()`.
"""
function sourceBasemeasure(m::AbstractModel)
    generate = sourceBasemeasure()
    return generate(Model(m))
end
"""
    sourceBasemeasure()

Return a code generator `Model -> Expr` for a model's base measure.

Statements are translated as follows; the results are combined by `buildSource`
and flattened with `MacroTools.flatten`:

- `Assign` (`x = rhs`) is reproduced as the same assignment.
- `Sample` (`x ~ rhs`) multiplies `basemeasure(rhs)` into the accumulator `_bm`,
  then rebinds `x` to `Soss.predict(rhs, x)`.
- `Return` and `LineNumber` yield `nothing`, meaning no code of their own.

The wrapper initialises `_bm = ProductMeasure(())`, splices in the translated
statements, and returns `_bm`.
"""
function sourceBasemeasure()
    return function(_m::Model)
        # Assignments are kept verbatim so later statements can refer to them.
        proc(_m, st::Assign) = :($(st.x) = $(st.rhs))

        # Each sampled variable contributes the base measure of its distribution.
        function proc(_m, st::Sample)
            lhs, dist = st.x, st.rhs
            return @q begin
                _bm *= basemeasure($dist)
                $lhs = Soss.predict($dist, $lhs)
            end
        end

        # These statement kinds emit no code.
        proc(_m, st::Return) = nothing
        proc(_m, st::LineNumber) = nothing

        # NOTE(review): `ProductMeasure(()) * μ` relies on a MeasureTheory method
        # for an empty ProductMeasure (see the patch in the gist comment).
        wrap(body) = @q begin
            _bm = ProductMeasure(())
            $body
            return _bm
        end

        return buildSource(_m, proc, wrap) |> MacroTools.flatten
    end
end
# Generated implementation behind `MeasureTheory.basemeasure(::ConditionalModel, x)`.
#
# Declared with GeneralizedGenerated's `@gg` and the module argument `M`.
# `M` is a `TypeLevel{Module}`; `from_type(M)` recovers a value from it, and the
# generated code binds that value to `M` in a `let` block. Presumably this makes
# the model's names resolve in the right module — TODO confirm.
#
# The body of the `let` is built as follows:
# - `type2model(_m)` reconstructs the model;
# - `sourceBasemeasure()` turns it into base-measure code;
# - `loadvals(_args, _data, _pars)` wraps that code with the argument, data and
#   parameter values.
# NOTE(review): inside `@gg` the arguments are presumably seen as types at
# generation time — verify against GeneralizedGenerated docs.
@gg M function _basemeasure(_::Type{M}, _m::Model, _args, _data, _pars) where M <: TypeLevel{Module}
Expr(:let,
Expr(:(=), :M, from_type(M)),  # `let M = <module recovered from the type>`
type2model(_m) |> sourceBasemeasure() |> loadvals(_args, _data, _pars))  # `let` body
end
@mschauer
Copy link
Author

mschauer commented Mar 6, 2021

This needs a patch to MeasureTheory.jl:

# Patch intended to live inside MeasureTheory.jl. Defining it from another
# package would be type piracy, since MeasureTheory owns `ProductMeasure` and
# `AbstractMeasure`.
#
# Multiplying the empty product measure by a measure `ν` gives the one-factor
# product `ProductMeasure((ν,))`. Generated code that starts from
# `ProductMeasure(())` and accumulates factors with `*=` hits this case on its
# first factor.
function Base.:*(μ::ProductMeasure{Tuple{}}, ν::AbstractMeasure)
    return ProductMeasure((ν,))
end

because otherwise there is a method ambiguity involving the `ProductMeasure` constructor.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment