Last active
March 7, 2021 23:41
-
-
Save mschauer/2b3cdc1f811a31607f8727c0f4cd00a0 to your computer and use it in GitHub Desktop.
Get reference measure
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
export basemeasure
import MeasureTheory

"""
    basemeasure(c::ConditionalModel, x=NamedTuple())

Get the reference measure of the conditional model `c`.

Forwards to `_basemeasure`, passing the model's type-level module parameter
`M`, `Model(c)`, `argvals(c)`, `observations(c)` and `x` (an empty
`NamedTuple` when omitted).
"""
function MeasureTheory.basemeasure(c::ConditionalModel{A,B,M}, x=NamedTuple()) where {A,B,M}
    model = Model(c)
    args = argvals(c)
    data = observations(c)
    return _basemeasure(M, model, args, data, x)
end
export sourceBasemeasure

"""
    sourceBasemeasure(m::AbstractModel)

Convert `m` with `Model(m)` and apply the builder returned by the
zero-argument `sourceBasemeasure()` to it.
"""
function sourceBasemeasure(m::AbstractModel)
    builder = sourceBasemeasure()
    return builder(Model(m))
end
"""
    sourceBasemeasure()

Return a function that takes a `Model` and builds the source code of its base
measure with `buildSource`.

In the generated code `_bm` starts as `ProductMeasure(())`. Each `Sample`
statement multiplies `basemeasure(rhs)` into `_bm` and rebinds its variable to
`Soss.predict(rhs, x)`; each `Assign` statement is emitted as a plain
assignment; `Return` and `LineNumber` statements emit nothing. The generated
code ends with `return _bm`, and the whole expression is passed through
`MacroTools.flatten`.
"""
function sourceBasemeasure()
    return function (_m::Model)
        # Statements that contribute no code to the generated source.
        proc(_m, st::Union{Return,LineNumber}) = nothing

        # Plain assignment: emit `x = rhs` as written in the model.
        proc(_m, st::Assign) = Expr(:(=), st.x, st.rhs)

        # Sampled variable: accumulate its base measure into `_bm`, then
        # rebind the variable through `Soss.predict`.
        proc(_m, st::Sample) = @q begin
            _bm *= basemeasure($(st.rhs))
            $(st.x) = Soss.predict($(st.rhs), $(st.x))
        end

        # Surround the per-statement code with the accumulator setup and the
        # final `return`.
        function wrap(kernel)
            return @q begin
                _bm = ProductMeasure(())
                $kernel
                return _bm
            end
        end

        return MacroTools.flatten(buildSource(_m, proc, wrap))
    end
end
# Generated implementation behind `basemeasure`. The body builds an expression
# instead of computing a value: a `let` block that binds `M` to `from_type(M)`
# around the code produced by applying `sourceBasemeasure()` to
# `type2model(_m)` and then `loadvals(_args, _data, _pars)` to that result.
@gg M function _basemeasure(_::Type{M}, _m::Model, _args, _data, _pars) where M <: TypeLevel{Module}
    Expr(
        :let,
        Expr(:(=), :M, from_type(M)),
        loadvals(_args, _data, _pars)(sourceBasemeasure()(type2model(_m))),
    )
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Needs a patch to MeasureTheory.jl because of an ambiguity in the constructor.