import numpy as np
import theano
import theano.tensor as tt
import pymc3 as pm
from symbolic_pymc.theano.random_variables import (
NormalRV,
MvNormalRV,
DirichletRV,
CategoricalRV,
Observed,
observed,
)
from symbolic_pymc.theano.opt import ScanArgs
from theano.printing import debugprint as tt_dprint
theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
theano.config.compute_test_value = "warn"  # 'ignore'
rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_init_state = rng_state.get_state()
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt
Consider the simple recursive model defined by the Theano graph constructed in Listing input-scan-mit-sot.
def input_step_fn(y_tm1, y_tm2, rng):
y_tm1.name = 'y_tm1'
y_tm2.name = 'y_tm2'
return NormalRV(y_tm1 + y_tm2, 1.0, rng=rng, name='Y_t')
Y_rv, _ = theano.scan(fn=input_step_fn,
outputs_info=[
{"initial": tt.as_tensor_variable(np.r_[-1.0, 0.0]),
"taps": [-1, -2]},
],
non_sequences=[rng_tt],
n_steps=10)
Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'
input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
tt_dprint(Y_rv)
As the output of Listing input-scan-mit-sot shows, src_python[:eval never]{Scan} nodes have inner-graphs. The graphs that the function src_python[:eval never]{scan} returns are elements of the outer-graph, as are the inputs we gave to src_python[:eval never]{scan}. Each graph (inner and outer) has corresponding inputs and outputs. In other words, there are four kinds of objects to track (see the sketch following this list):
- outer-graph inputs (i.e. the Theano objects in src_python[:eval never]{outputs_info}, src_python[:eval never]{non_sequences}, src_python[:eval never]{sequences}, etc.),
- inner-graph inputs (i.e. Theano objects used to represent the arguments to src_python[:eval never]{input_step_fn}),
- inner-graph outputs (i.e. the Theano objects representing the return values of src_python[:eval never]{input_step_fn}), and
- outer-graph outputs (i.e. the Theano objects returned by src_python[:eval never]{theano.scan}, like src_python[:eval never]{Y_rv}).
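These four collections are also accessible from the src_python[:eval never]{Scan} apply node itself; the following is a rough sketch, assuming the src_python[:eval never]{Y_rv} graph above:
# The `Scan` apply node sits behind the `Subtensor` that drops the initial
# values, hence the `.owner.inputs[0].owner`
scan_node = Y_rv.owner.inputs[0].owner
outer_inputs = scan_node.inputs        # outer-graph inputs
outer_outputs = scan_node.outputs      # outer-graph outputs
inner_inputs = scan_node.op.inputs     # inner-graph inputs
inner_outputs = scan_node.op.outputs   # inner-graph outputs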
The src_python[:eval never]{ScanArgs} class nicely collects these objects. Listing input-scan-mit-sot-arguments shows all of these inputs and outputs for our src_python[:eval never]{Scan} result in Listing input-scan-mit-sot. As we can see, each type of input and output gets broken down further into an input “type” (e.g. outer-graph input sequences given by src_python[:eval never]{outer_in_seqs} and outer-graph input non-sequences given by src_python[:eval never]{outer_in_non_seqs}).
print(input_scan_args)
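Individual field collections can also be inspected directly; a small sketch, assuming the src_python[:eval never]{input_scan_args} object above:
# The outer-graph non-sequence inputs and their inner-graph counterparts
print(input_scan_args.outer_in_non_seqs)
print(input_scan_args.inner_in_non_seqs)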
Naturally, the outer-graph inputs connect to the inner-graph inputs and the inner-graph outputs connect to the outer-graph outputs, but we’ll return to this later.
Our objective is to convert the sample-space graph in Listing input-scan-mit-sot into a measure-space graph.
In this instance, a sample-space graph is simply a Theano graph that represents the relationships between random variables. Implementation-wise, a sample-space graph uses src_python[:eval never]{RandomVariable} operators and the output of such a graph can be compiled into a Theano function that computes random samples from the distribution/model implied by the graph.
A measure-space graph represents the calculation of a measure of a random variable or function. When the random variable or function represents a model, the measure can be interpreted as the likelihood (or log-likelihood, if desired).
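To make the distinction concrete for a single random variable, consider the following minimal sketch (the names src_python[:eval never]{mu_x}, src_python[:eval never]{x_rv}, and src_python[:eval never]{x_obs} are purely illustrative), which assumes the src_python[:eval never]{NormalRV} and src_python[:eval never]{rng_tt} objects defined above:
mu_x = tt.scalar("mu_x")
mu_x.tag.test_value = 0.0
# Sample-space: compiling this graph yields random draws of `x`
x_rv = NormalRV(mu_x, 1.0, rng=rng_tt, name="x")
# Measure-space: this graph computes the log-likelihood of an observed value
x_obs = tt.scalar("x_obs")
x_obs.tag.test_value = 0.0
x_logp = pm.Normal.dist(mu_x, 1.0).logp(x_obs)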
For the sample-space graph given in Listing input-scan-mit-sot, Listing output-scan-mit-sot illustrates a corresponding measure-space graph.
Y_obs = tt.as_tensor_variable(Y_rv.eval())
Y_obs.name = 'Y_obs'
def output_step_fn(y_t, y_tm1, y_tm2):
y_t.name = 'y_t'
y_tm1.name = 'y_tm1'
y_tm2.name = 'y_tm2'
logp = pm.Normal.dist(y_tm1 + y_tm2, 1.0).logp(y_t)
logp.name = 'logp(y_t)'
return logp
Y_logp, _ = theano.scan(fn=output_step_fn,
sequences=[
{"input": Y_obs, "taps": [0, -1, -2]}
],
outputs_info=[
{}
])
output_scan_args = ScanArgs.from_node(Y_logp.owner)
Ultimately, our objective requires us to transform Listing input-scan-mit-sot into the src_python[:eval never]{Scan} in Listing output-scan-mit-sot in a generalized and systematic way.
This further requires us to transform the src_python[:eval never]{Scan} arguments in Listing input-scan-mit-sot-arguments into the src_python[:eval never]{Scan} arguments implied by the measure-space model, shown in Listing output-scan-mit-sot-arguments.
print(output_scan_args)
What happens if we use only one “tap” (i.e. lagged element)? In Listing input-scan-sit-sot, we modify the original sample-space src_python[:eval never]{Scan} to reflect this.
def input_step_fn(y_tm1, rng):
y_tm1.name = 'y_tm1'
return NormalRV(y_tm1, 1.0, rng=rng, name='Y_t')
Y_rv, _ = theano.scan(fn=input_step_fn,
outputs_info=[
{"initial": tt.as_tensor_variable(np.r_[-1.0, 0.0]),
"taps": [-1]},
],
non_sequences=[rng_tt],
n_steps=10)
Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'
input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
print(input_scan_args)
The output of Listing input-scan-sit-sot-arguments shows us that the tap terms (i.e. src_python[:eval never]{y_tm1}) are now in the fields with a src_python[:eval never]{"sit_sot"} suffix. The corresponding measure-space graph is given in Listing output-scan-sit-sot.
Y_obs = tt.as_tensor_variable(Y_rv.eval())
Y_obs.name = 'Y_obs'
def output_step_fn(y_t, y_tm1):
y_t.name = 'y_t'
y_tm1.name = 'y_tm1'
logp = pm.Normal.dist(y_tm1, 1.0).logp(y_t)
logp.name = 'logp(y_t)'
return logp
Y_logp, _ = theano.scan(fn=output_step_fn,
sequences=[
{"input": Y_obs, "taps": [0, -1]}
],
outputs_info=[
{}
])
output_scan_args = ScanArgs.from_node(Y_logp.owner)
print(output_scan_args)
The example in Listing input-scan-sit-mit-sot demonstrates how a src_python[:eval never]{Scan} that seemingly specifies a sit-sot is actually represented as a mit-sot. The defining characteristic is the lag/tap order: by making the tap src_python[:eval never]{-3} instead of src_python[:eval never]{-1}, an output with a single input tap is nevertheless stored as a “multiple input taps” (mit-sot) term.
def input_step_fn(y_tm1, rng):
y_tm1.name = 'y_tm1'
return NormalRV(y_tm1, 1.0, rng=rng, name='Y_t')
Y_rv, _ = theano.scan(fn=input_step_fn,
outputs_info=[
{"initial": tt.as_tensor_variable(np.r_[-1.0, 0.0]),
"taps": [-3]},
],
non_sequences=[rng_tt],
n_steps=10)
Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'
input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
print(input_scan_args)
Notice, in Listing input-scan-sit-mit-sot-arguments, how our old sit-sot has suddenly become a mit-sot! Clearly, a sit-sot is just a mit-sot with taps src_python[:eval never]{[-1]}, making the distinction somewhat arbitrary and the implementation a bit redundant.
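We can confirm this directly in the src_python[:eval never]{ScanArgs} fields; a small sketch, assuming the src_python[:eval never]{input_scan_args} object from Listing input-scan-sit-mit-sot:
print(input_scan_args.inner_in_sit_sot)  # e.g. [] (no sit-sot terms remain)
print([v.name for v in input_scan_args.inner_in_mit_sot[0]])  # e.g. ['y_tm1']
print(input_scan_args.mit_sot_in_slices)  # e.g. [[-3]]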
The situation in Listing input-scan-sit-sot-multi illustrates a src_python[:eval never]{Scan} with multiple inter-dependent output taps. In this instance, we need to either include the other outputs in the log-likelihood-computing src_python[:eval never]{Scan}, or convert both outputs to inputs and make the inner-graph log-likelihood depend on the resulting new input.
def input_step_fn(mu_tm1, y_tm1, rng):
mu_tm1.name = 'mu_tm1'
y_tm1.name = 'y_tm1'
mu = mu_tm1 + y_tm1 + 1
mu.name = 'mu_t'
return mu, NormalRV(mu, 1.0, rng=rng, name='Y_t')
(mu_tt, Y_rv), _ = theano.scan(fn=input_step_fn,
outputs_info=[
{"initial": tt.as_tensor_variable(np.r_[0.0]),
"taps": [-1]},
{"initial": tt.as_tensor_variable(np.r_[0.0]),
"taps": [-1]},
],
non_sequences=[rng_tt],
n_steps=10)
mu_tt.name = 'mu_tt'
mu_tt.owner.inputs[0].name = 'mu_all'
Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'
input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
print(input_scan_args)
Listing output-scan-sit-sot-multi provides a log-likelihood-computing src_python[:eval never]{Scan} for Listing input-scan-sit-sot-multi.
Y_obs = tt.as_tensor_variable(Y_rv.eval())
Y_obs.name = 'Y_obs'
def output_step_fn(y_t, y_tm1, mu_tm1):
mu_tm1.name = 'mu_tm1'
y_tm1.name = 'y_tm1'
mu = mu_tm1 + y_tm1 + 1
mu.name = 'mu_t'
logp = pm.Normal.dist(mu, 1.0).logp(y_t)
logp.name = 'logp'
return mu, logp
(mu_tt, Y_logp), _ = theano.scan(fn=output_step_fn,
sequences=[
{"input": Y_obs, "taps": [0, -1]}
],
outputs_info=[
{"initial": tt.as_tensor_variable(np.r_[0.0]),
"taps": [-1]},
{}
])
Y_logp.name = 'Y_logp'
mu_tt.name = 'mu_tt'
output_scan_args = ScanArgs.from_node(Y_logp.owner)
Notice that the sample-space src_python[:eval never]{Scan}'s src_python[:eval never]{inner_in_sit_sot} term src_python[:eval never]{y_tm1} becomes an element of src_python[:eval never]{inner_in_seqs}, while src_python[:eval never]{mu_tm1} remains a sit-sot term.
print(output_scan_args)
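As a quick check of the classifications just described, one could inspect the relevant fields directly; a small sketch, assuming the src_python[:eval never]{output_scan_args} object above:
print([v.name for v in output_scan_args.inner_in_seqs])  # e.g. ['y_t', 'y_tm1']
print([v.name for v in output_scan_args.inner_in_sit_sot])  # e.g. ['mu_tm1']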
Now let's consider a more involved example: a switching process, i.e. a hidden Markov model (HMM), in which a discrete state sequence determines the mean of a normal emission.
rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt
N_tt = tt.iscalar("N")
N_tt.tag.test_value = 10
M_tt = tt.iscalar("M")
M_tt.tag.test_value = 2
mus_tt = tt.matrix("mus")
mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype(
theano.config.floatX
)
sigmas_tt = tt.ones((N_tt,))
sigmas_tt.name = "sigmas"
pi_0_rv = DirichletRV(tt.ones((M_tt,)), rng=rng_tt, name="pi_0")
Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")
S_0_rv = CategoricalRV(pi_0_rv, rng=rng_tt, name="S_0")
def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t, rng):
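    # One step of the HMM: a categorical state transition followed by a
    # normal emission whose mean is selected by the new state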
gamma_t = Gamma_t[S_tm1]
gamma_t.name = "gamma_t"
S_t = CategoricalRV(gamma_t, rng=rng, name="S_t")
mu_t = mus_t[S_t]
mu_t.name = "mu[S_t]"
Y_t = NormalRV(mu_t, sigma_t, rng=rng, name="Y_t")
return S_t, Y_t
(S_rv, Y_rv), scan_updates = theano.scan(
fn=scan_fn,
sequences=[mus_tt, sigmas_tt],
non_sequences=[Gamma_rv, rng_tt],
outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {}],
strict=True,
name="scan_rv",
)
Y_rv.name = "Y_rv"
input_scan_args = ScanArgs.from_node(Y_rv.owner)
print(input_scan_args)
test_point = {
M_tt: 2,
N_tt: 10,
mus_tt: mus_tt.tag.test_value,
}
Y_obs = tt.as_tensor_variable(Y_rv.eval(test_point))
Y_obs.name = 'Y_obs'
def logp_scan_fn(y_t, mus_t, sigma_t, S_tm1, Gamma_t, rng):
gamma_t = Gamma_t[S_tm1]
gamma_t.name = "gamma_t"
S_t = CategoricalRV(gamma_t, rng=rng, name="S_t")
mu_t = mus_t[S_t]
mu_t.name = "mu[S_t]"
Y_logp_t = pm.Normal.dist(mu_t, sigma_t).logp(y_t)
Y_logp_t.name = "logp(y_t)"
return S_t, Y_logp_t
(S_rv, Y_logp), scan_updates = theano.scan(
fn=logp_scan_fn,
sequences=[Y_obs, mus_tt, sigmas_tt],
non_sequences=[Gamma_rv, rng_tt],
outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {}],
strict=True,
name="scan_rv",
)
Y_logp.name = "Y_logp"
output_scan_args = ScanArgs.from_node(Y_logp.owner)
print(output_scan_args)
In Listing convert_outer_out_to_in, we create a function that converts an outer-graph output in a src_python[:eval never]{ScanArgs} into an outer-graph input.
def convert_outer_out_to_in(input_scan_args, var, inner_out_fn=None, output_scan_args=None):
"""Convert outer-graph outputs into outer-graph inputs.
Parameters
----------
input_scan_args: ScanArgs
The source scan arguments.
var: TensorVariable
The outer-graph output variable that is to be converted into an
outer-graph input.
    inner_out_fn: function (Optional)
        A function with the signature `(input_scan_args, old_inner_out_var,
        new_outer_input_var, output_scan_args)` that produces a new inner-graph
        output.  This can be used to transform `var`'s corresponding
        inner-graph output, for example.
    output_scan_args: ScanArgs (Optional)
        If this argument is non-`None`, the conversion is applied to the given
        `ScanArgs` and the old `var` output is removed.
    Returns
-------
(ScanArgs, new_outer_input_var)
A `tuple` containing a `ScanArgs` object with the outer-graph output given by `var` converted
to an outer-graph input, and a variable representing the new outer-graph input.
Additionally, some meta information is attached to the output `ScanArgs`
object that can be used to remove initial values from the outer-graph
outputs.
"""
replacing = False
if output_scan_args is None:
output_scan_args = ScanArgs.create_empty()
elif output_scan_args == input_scan_args:
replacing = True
    # We will not change the input `ScanArgs` in-place
    if output_scan_args is input_scan_args:
        from copy import copy
        output_scan_args = copy(input_scan_args)
var_info = input_scan_args.find_among_fields(var, field_filter=lambda x: x.startswith("outer_out"))
old_inner_out_var = input_scan_args.get_alt_field(var_info, "inner_out")
if replacing:
# Remove the old outer-output variable
# Not sure if this really matters, since we don't use the outer-outputs
# when building a new `Scan`, but doing it keeps the `ScanArgs` object
# consistent.
output_scan_args.remove_from_fields(var, rm_dependents=False)
output_scan_args.remove_from_fields(old_inner_out_var, rm_dependents=False)
# Couldn't one do the same with `var_info`?
inner_out_info = input_scan_args.find_among_fields(old_inner_out_var,
field_filter=lambda x: x.startswith("inner_out"))
# Use the index for the specific inner-graph sub-collection to which this
# variable belongs (e.g. index `1` among the inner-graph sit-sot terms)
var_idx = inner_out_info.index
    # The old inner-output variable becomes the new inner-input
inner_in_var = old_inner_out_var.clone()
# We need to clone any existing inner-output variables in the `ScanArgs`
# object that we're mutating and replace references to `old_inner_out_var`
# with `inner_in_var`. If we don't, then any other inner-outputs that
# reference the inner-output that we're replacing will be inconsistent.
# Instead, we want those other inner-outputs to reference the new
# inner-input replacement variable.
from theano.scan_module.scan_utils import clone as tt_clone
for io_var in list(output_scan_args.inner_outputs):
io_var_info = output_scan_args.find_among_fields(io_var, field_filter=lambda x: x.startswith("inner_out"))
io_sub_list = getattr(output_scan_args, io_var_info.name)
new_io_var, = tt_clone([io_var], replace={old_inner_out_var: inner_in_var})
io_sub_list[io_var_info.index] = new_io_var
inner_in_seqs = [inner_in_var]
if inner_out_info.name.endswith('mit_sot'):
inner_in_seqs = input_scan_args.inner_in_mit_sot[var_idx] + inner_in_seqs
if replacing:
output_scan_args.inner_in_mit_sot.pop(var_idx)
output_scan_args.outer_in_mit_sot.pop(var_idx)
elif inner_out_info.name.endswith('sit_sot'):
inner_in_seqs = [input_scan_args.inner_in_sit_sot[var_idx]] + inner_in_seqs
if replacing:
output_scan_args.inner_in_sit_sot.pop(var_idx)
output_scan_args.outer_in_sit_sot.pop(var_idx)
taps = [0]
if inner_out_info.name.endswith('mit_sot'):
taps = input_scan_args.mit_sot_in_slices[var_idx] + taps
if replacing:
output_scan_args.mit_sot_in_slices.pop(var_idx)
elif inner_out_info.name.endswith('sit_sot'):
taps = [-1] + taps
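    # Sort the taps into ascending order, keeping the corresponding
    # inner-input variables aligned with them; the inner-inputs are then
    # reversed into descending-tap order (e.g. `[y_t, y_tm1]` for taps
    # `[0, -1]`) to match the order of the sequence slices built below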
taps, inner_in_seqs = zip(*sorted(zip(taps, inner_in_seqs), key=lambda x: x[0]))
inner_in_seqs = list(reversed(inner_in_seqs))
output_scan_args.inner_in_seqs += inner_in_seqs
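    # Each tap becomes a slice of the new outer-input sequence: e.g. for
    # taps `[-1, 0]` this yields `var[1:]` (tap `0`) and `var[0:-1]`
    # (tap `-1`), in the same descending-tap order as the inner-inputs above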
taps = np.asarray(taps)
slice_seqs = zip(-taps, [n if n < 0 else None for n in reversed(taps)])
# This variable is our new outer-input.
new_input_var = var.clone()
var_slices = [new_input_var[b:e] for b, e in slice_seqs]
n_steps = tt.min([tt.shape(n)[0] for n in var_slices])
# n_steps -= np.abs(start_idx) + np.abs(stop_idx)
if output_scan_args.n_steps is None or replacing:
output_scan_args.n_steps = n_steps
# output_scan_args.input_taps = taps
output_scan_args.outer_in_seqs += [v[:n_steps] for v in var_slices]
output_scan_args.outer_in_nit_sot += [n_steps]
if inner_out_fn:
output_scan_args.inner_out_nit_sot += [
inner_out_fn(input_scan_args, old_inner_out_var, new_input_var, output_scan_args)
]
return (output_scan_args, new_input_var)
from theano.scan_module.scan_op import Scan
from symbolic_pymc.theano.ops import RandomVariable
def get_random_outer_outputs(scan_args):
"""Get the `RandomVariable` outputs of a `Scan` (well, it's `ScanArgs`)."""
rv_vars = []
for n, oo in enumerate(scan_args.outer_outputs):
oo_info = scan_args.find_among_fields(oo)
io_type = oo_info.name[(oo_info.name.index('_', 6) + 1):]
inner_out_type = "inner_out_{}".format(io_type)
io_var = getattr(scan_args, inner_out_type)[oo_info.index]
if io_var.owner and isinstance(io_var.owner.op, RandomVariable):
rv_vars.append((n, oo))
return rv_vars
def inner_out_fn(input_scan_args, old_inner_out_var, new_outer_input_var, output_scan_args):
from symbolic_pymc.theano.pymc3 import _logp_fn
Y_t = old_inner_out_var
y_t = output_scan_args.inner_in_seqs[0]
    logp = _logp_fn(Y_t.owner.op, Y_t.owner, None)(y_t)
logp.name = "logp(y_t)"
return logp
<<input-scan-mit-sot>>
<<input-scan-mit-sot-arguments>>
# Pick the first one for testing
var_idx, var = get_random_outer_outputs(input_scan_args)[0]
test_scan_args, test_Y_rv = convert_outer_out_to_in(input_scan_args, var,
inner_out_fn=inner_out_fn,
output_scan_args=input_scan_args)
# The new `ScanArgs` object and the symbolic outer-input variable for our new log-likelihood graph
test_scan_args
output_scan_args
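# With the converted `ScanArgs` in hand, we can construct a new `Scan` op
# directly from its inner-graph inputs/outputs and its `info` fields, then
# apply it to the outer-graph inputs to obtain the new outer-graph outputs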
test_scan_op = Scan(test_scan_args.inner_inputs,
test_scan_args.inner_outputs,
test_scan_args.info)
scan_out = test_scan_op(*test_scan_args.outer_inputs)
if not isinstance(scan_out, list):
    test_scan_node = scan_out.owner
    scan_out = [scan_out]
else:
    test_scan_node = scan_out[0].owner
# The actual input of this should be `Y_obs`
<<output-scan-mit-sot>>
<<output-scan-mit-sot-arguments>>
# tt_dprint(scan_out)
# Y_rv
# tt_dprint(Y_rv)
rng_tt.get_value(borrow=True).set_state(rng_init_state)
res = scan_out[var_idx].eval({test_Y_rv: Y_obs.value})
rng_tt.get_value(borrow=True).set_state(rng_init_state)
exp_res = Y_logp.eval()
assert np.array_equal(res, exp_res)
<<logp-imports>>
<<testing-utils>>
theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
theano.config.compute_test_value = "warn"  # 'ignore'
rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_init_state = rng_state.get_state()
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt
N_tt = tt.iscalar("N")
N_tt.tag.test_value = 10
M_tt = tt.iscalar("M")
M_tt.tag.test_value = 2
mus_tt = tt.matrix("mus_t")
mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype(
theano.config.floatX
)
sigmas_tt = tt.ones((N_tt,))
sigmas_tt.name = "sigmas"
pi_0_rv = DirichletRV(tt.ones((M_tt,)), rng=rng_tt, name="pi_0")
Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")
S_0_rv = CategoricalRV(pi_0_rv, rng=rng_tt, name="S_0")
def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t, rng):
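    # NB: unlike the earlier HMM step function, `S_t` here does not depend on
    # the previous state `S_tm1`, so the `push_out_rvs_from_scan` optimization
    # used below can presumably move it out of the inner-graph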
S_t = CategoricalRV(Gamma_t[0], rng=rng, name="S_t")
Y_t = NormalRV(mus_t[S_t], sigma_t, rng=rng, name="Y_t")
return S_t, Y_t
(S_rv, Y_rv), _ = theano.scan(
fn=scan_fn,
sequences=[mus_tt, sigmas_tt],
non_sequences=[Gamma_rv, rng_tt],
outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {},],
strict=True,
name="scan_rv",
)
S_rv.name = "S_rv"
Y_rv.name = "Y_rv"
from theano.gof import Query
from theano.compile import optdb
from theano.gof.opt import SeqOptimizer, EquilibriumOptimizer
from theano.gof.graph import inputs as tt_inputs, io_toposort as tt_io_toposort
from symbolic_pymc.theano.opt import push_out_rvs_from_scan, convert_outer_out_to_in, FunctionGraph, optimize_graph
from symbolic_pymc.theano.pymc3 import _logp_fn
output_vars = (Y_rv,)
def get_random_outer_outputs(scan_args):
"""Get the `RandomVariable` outputs of a `Scan` (well, it's `ScanArgs`)."""
rv_vars = []
for n, oo in enumerate(scan_args.outer_outputs):
oo_info = scan_args.find_among_fields(oo)
io_type = oo_info.name[(oo_info.name.index("_", 6) + 1) :]
inner_out_type = "inner_out_{}".format(io_type)
io_var = getattr(scan_args, inner_out_type)[oo_info.index]
if io_var.owner and isinstance(io_var.owner.op, RandomVariable):
rv_vars.append((n, oo))
return rv_vars
def create_inner_out_logp(input_scan_args, old_inner_out_var, new_inner_in_var, output_scan_args):
"""Create a log-likelihood inner-output.
    This is intended to be used with `get_random_outer_outputs`.
"""
logp_fn = _logp_fn(old_inner_out_var.owner.op, old_inner_out_var.owner, None)
logp = logp_fn(new_inner_in_var)
if new_inner_in_var.name:
logp.name = "logp({})".format(new_inner_in_var.name)
return logp
def construct_scan(scan_args):
scan_op = Scan(scan_args.inner_inputs, scan_args.inner_outputs, scan_args.info)
scan_out = scan_op(*scan_args.outer_inputs)
if not isinstance(scan_out, list):
scan_out = [scan_out]
return scan_out
def logp(*output_vars):
"""Compute the log-likelihood for a graph.
Parameters
----------
    *output_vars: Tuple[TensorVariable]
        The outputs of a graph containing `RandomVariable`s.
    Returns
-------
Dict[TensorVariable, TensorVariable]
A map from `RandomVariable`s to their log-likelihood graphs.
"""
model_fgraph = FunctionGraph(tt_inputs(output_vars), output_vars, clone=True)
canonicalize_opt = optdb.query(Query(include=["canonicalize"]))
optimizations = SeqOptimizer(canonicalize_opt.copy())
optimizations.append(EquilibriumOptimizer([push_out_rvs_from_scan], max_use_ratio=10))
opt_fgraph = optimize_graph(model_fgraph, optimizations, in_place=True)
rv_to_logp_io = {}
for node in opt_fgraph.toposort():
if isinstance(node.op, RandomVariable):
var = node.default_output()
rv_to_logp_io[var] = _logp_fn(node.op, var.owner, None)(var)
elif isinstance(node.op, Scan):
scan_args = ScanArgs.from_node(node)
rv_outer_outs = get_random_outer_outputs(scan_args)
for var_idx, var in rv_outer_outs:
scan_args = convert_outer_out_to_in(
scan_args, var,
inner_out_fn=create_inner_out_logp,
output_scan_args=scan_args
)
logp_scan_out = construct_scan(scan_args)
for var_idx, var in rv_outer_outs:
rv_to_logp_io[var] = logp_scan_out[var_idx]
return rv_to_logp_io
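Finally, a minimal usage sketch (assuming the HMM graph above): src_python[:eval never]{logp} returns a map from each src_python[:eval never]{RandomVariable} in the optimized (i.e. cloned) graph to its log-likelihood graph, so we simply iterate over the result.
rv_to_logp = logp(Y_rv)
for rv, rv_logp in rv_to_logp.items():
    print(rv, rv_logp.type)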