Illustrations of `theano.scan` structure and manipulations for log-likelihood generation

Computing log-likelihood graphs for Scans


Introduction

import numpy as np
import theano
import theano.tensor as tt

import pymc3 as pm

from symbolic_pymc.theano.random_variables import (
    NormalRV,
    MvNormalRV,
    DirichletRV,
    CategoricalRV,
    Observed,
    observed,
)

from symbolic_pymc.theano.opt import ScanArgs
from theano.printing import debugprint as tt_dprint


theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
tt.config.compute_test_value = "warn"  # 'ignore'

rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_init_state = rng_state.get_state()
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt

Consider the simple recursive model defined by the Theano graph constructed in Listing 2.

def input_step_fn(y_tm1, y_tm2, rng):
    y_tm1.name = 'y_tm1'
    y_tm2.name = 'y_tm2'
    return NormalRV(y_tm1 + y_tm2, 1.0, rng=rng, name='Y_t')


Y_rv, _ = theano.scan(fn=input_step_fn,
                      outputs_info=[
                          {"initial": tt.as_tensor_variable(np.r_[-1.0, 0.0]),
                           "taps": [-1, -2]},
                      ],
                      non_sequences=[rng_tt],
                      n_steps=10)

Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'

input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
Listing 2
tt_dprint(Y_rv)
Subtensor{int64::} [id A] 'Y_rv'
 |for{cpu,scan_fn}.0 [id B] 'Y_all'
 | |TensorConstant{10} [id C]
 | |IncSubtensor{Set;:int64:} [id D] ''
 | | |AllocEmpty{dtype='float64'} [id E] ''
 | | | |Elemwise{add,no_inplace} [id F] ''
 | | |   |TensorConstant{10} [id C]
 | | |   |Subtensor{int64} [id G] ''
 | | |     |Shape [id H] ''
 | | |     | |Subtensor{:int64:} [id I] ''
 | | |     |   |TensorConstant{[-1.  0.]} [id J]
 | | |     |   |Constant{2} [id K]
 | | |     |Constant{0} [id L]
 | | |Subtensor{:int64:} [id I] ''
 | | |ScalarFromTensor [id M] ''
 | |   |Subtensor{int64} [id G] ''
 | |rng [id N]
 |Constant{2} [id O]

Inner graphs of the scan ops:

for{cpu,scan_fn}.0 [id B] 'Y_all'
 >normal_rv.1 [id P] 'Y_t'
 > |Elemwise{add,no_inplace} [id Q] ''
 > | |y_tm1 [id R] -> [id D]
 > | |y_tm2 [id S] -> [id D]
 > |TensorConstant{1.0} [id T]
 > |TensorConstant{[]} [id U]
 > |rng_copy [id V] -> [id N]
 >DeepCopyOp [id W] 'rng_copy'
 > |rng_copy [id V] -> [id N]

As the output of Listing 2 shows us, Scan nodes have inner-graphs. The graphs that the function scan returns are elements of the outer-graph, and so are the inputs we gave to scan. Each graph, inner and outer, has corresponding inputs and outputs. In other words, there are

  • outer-graph inputs (i.e. the Theano objects in outputs_info, non_sequences, sequences, etc.),
  • inner-graph inputs (i.e. Theano objects used to represent the arguments to input_step_fn),
  • inner-graph outputs (i.e. the Theano objects representing the return values of input_step_fn),
  • outer-graph outputs (i.e. the Theano objects returned by theano.scan, like Y_rv).
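
These four groups can also be read directly off the Scan node itself; here is a minimal sketch, assuming Theano's scan_module.scan_op.Scan, which stores its inner-graph inputs and outputs as op.inputs and op.outputs:

scan_node = Y_rv.owner.inputs[0].owner  # the `for{cpu,scan_fn}` apply node

print(scan_node.inputs)      # outer-graph inputs: n_steps, initial values, rng
print(scan_node.op.inputs)   # inner-graph inputs: y_tm1, y_tm2, rng_copy
print(scan_node.op.outputs)  # inner-graph outputs: Y_t and the rng update
print(scan_node.outputs)     # outer-graph outputs: Y_all and the updated rng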

The ScanArgs class nicely collects these objects. Listing 5 shows all of these inputs and outputs for our Scan result in Listing 2. As we can see, each type of input and output gets broken down further into an input "type" (e.g. outer-graph input sequences given by outer_in_seqs and outer-graph input non-sequences given by outer_in_non_seqs).

print(input_scan_args)
Listing 5
ScanArgs(
  n_steps=TensorConstant{10},
  outer_in_seqs=[],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[IncSubtensor{Set;:int64:}.0],
  outer_in_sit_sot=[],
  outer_in_shared=[rng],
  outer_in_nit_sot=[],
  outer_in_non_seqs=[],
  inner_in_seqs=[],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[[y_tm1, y_tm2]],
  inner_in_sit_sot=[],
  inner_in_shared=[rng_copy],
  inner_in_non_seqs=[],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[[-1, -2]],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[Y_t],
  inner_out_sit_sot=[],
  inner_out_nit_sot=[],
  inner_out_shared=[rng_copy],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[Y_all],
  outer_out_sit_sot=[],
  outer_out_nit_sot=[],
  outer_out_shared=[for{cpu,scan_fn}.1])

Naturally, the outer-graph inputs connect to the inner-graph inputs and the inner-graph outputs connect to the outer-graph outputs, but we'll return to this later.
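
For instance, the mit-sot correspondence can be read straight off a ScanArgs; a small sketch using the field names printed in Listing 5:

# Each outer-input mit-sot term lines up with a list of inner-input
# variables and their taps:
for outer_in, inner_ins, taps in zip(input_scan_args.outer_in_mit_sot,
                                     input_scan_args.inner_in_mit_sot,
                                     input_scan_args.mit_sot_in_slices):
    # e.g. IncSubtensor{Set;:int64:}.0 -> [y_tm1, y_tm2] with taps [-1, -2]
    print(outer_in, "->", inner_ins, "taps:", taps)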

Our objective is to convert the sample-space graph in Listing 2 into a measure-space graph.

In this instance, a sample-space graph is simply a Theano graph that represents the relationships between random variables. Implementation-wise, a sample-space graph uses RandomVariable operators and the output of such a graph can be compiled into a Theano function that computes random samples from the distribution/model implied by the graph.

A measure-space graph represents the calculation of a measure of a random variable or function. When the random variable or function represents a model, the measure can be interpreted as the likelihood (or log-likelihood, if desired).
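
As a scalar analogue, using only objects already imported above, a NormalRV term lives in sample-space while the corresponding pm.Normal.dist(...).logp(...) term lives in measure-space:

x_rv = NormalRV(0.0, 1.0, rng=rng_tt, name="x")  # sample-space: draws a value
x_obs = tt.as_tensor_variable(1.5)
x_logp = pm.Normal.dist(0.0, 1.0).logp(x_obs)    # measure-space: scores a value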

For the sample-space graph given in Listing 2, Listing 7 illustrates a corresponding measure-space graph.

Y_obs = tt.as_tensor_variable(Y_rv.eval())
Y_obs.name = 'Y_obs'

def output_step_fn(y_t, y_tm1, y_tm2):
    y_t.name = 'y_t'
    y_tm1.name = 'y_tm1'
    y_tm2.name = 'y_tm2'
    logp = pm.Normal.dist(y_tm1 + y_tm2, 1.0).logp(y_t)
    logp.name = 'logp(y_t)'
    return logp

Y_logp, _ = theano.scan(fn=output_step_fn,
                        sequences=[
                            {"input": Y_obs, "taps": [0, -1, -2]}
                        ],
                        outputs_info=[
                            {}
                        ])

output_scan_args = ScanArgs.from_node(Y_logp.owner)
Listing 7

Ultimately, our objective requires us to transform Listing 2 into the Scan in Listing 7 in a generalized and systematic way.

This further requires us to transform the Scan arguments in Listing 5 to the Scan arguments implied by the measure-space model, shown in Listing 8.

print(output_scan_args)
Listing 8
ScanArgs(
  n_steps=Elemwise{minimum,no_inplace}.0,
  outer_in_seqs=[Subtensor{:int64:}.0, Subtensor{:int64:}.0, Subtensor{:int64:}.0],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[],
  outer_in_sit_sot=[],
  outer_in_shared=[],
  outer_in_nit_sot=[Elemwise{minimum,no_inplace}.0],
  outer_in_non_seqs=[],
  inner_in_seqs=[y_t, y_tm1, y_tm2],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[],
  inner_in_sit_sot=[],
  inner_in_shared=[],
  inner_in_non_seqs=[],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[],
  inner_out_sit_sot=[],
  inner_out_nit_sot=[logp(y_t)],
  inner_out_shared=[],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[],
  outer_out_sit_sot=[],
  outer_out_nit_sot=[for{cpu,scan_fn}.0],
  outer_out_shared=[])
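
At a high level, the conversion we implement later (in Listing 31) performs the following steps; a comment-only sketch:

# Outer-output -> outer-input conversion, in outline:
#   1. Locate `var` among the `ScanArgs` outer-outputs and find its
#      corresponding inner-output.
#   2. Clone that inner-output so it can serve as a new inner-input.
#   3. Move the old mit-/sit-sot taps, plus a new tap `0`, into the
#      sequence fields.
#   4. Slice a cloned outer-input by those taps and recompute `n_steps`.
#   5. Optionally build a replacement inner-output (e.g. a log-likelihood
#      term) via a callback.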

Cases

Single-tap Argument

What happens if we use only one "tap" (i.e. lagged element)? In Listing 13, we modify the original measure-space Scan to reflect this.

Sit-Sot

def input_step_fn(y_tm1, rng):
    y_tm1.name = 'y_tm1'
    return NormalRV(y_tm1, 1.0, rng=rng, name='Y_t')

Y_rv, _ = theano.scan(fn=input_step_fn,
                        outputs_info=[
                            {"initial": tt.as_tensor_variable(np.r_[-1.0, 0.0]),
                             "taps": [-1]},
                        ],
                        non_sequences=[rng_tt],
                        n_steps=10)

Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'

input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
print(input_scan_args)
Listing 11
ScanArgs(
  n_steps=TensorConstant{10},
  outer_in_seqs=[],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[],
  outer_in_sit_sot=[IncSubtensor{Set;:int64:}.0],
  outer_in_shared=[rng],
  outer_in_nit_sot=[],
  outer_in_non_seqs=[],
  inner_in_seqs=[],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[],
  inner_in_sit_sot=[y_tm1],
  inner_in_shared=[rng_copy],
  inner_in_non_seqs=[],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[],
  inner_out_sit_sot=[Y_t],
  inner_out_nit_sot=[],
  inner_out_shared=[rng_copy],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[],
  outer_out_sit_sot=[Y_all],
  outer_out_nit_sot=[],
  outer_out_shared=[for{cpu,scan_fn}.1])

The output of Listing 11 shows us that the tap terms (i.e. y_tm1) are now in the fields with a "sit_sot" suffix. The corresponding measure-space graph is given in Listing 13.

Y_obs = tt.as_tensor_variable(Y_rv.eval())
Y_obs.name = 'Y_obs'

def output_step_fn(y_t, y_tm1):
    y_t.name = 'y_t'
    y_tm1.name = 'y_tm1'
    logp = pm.Normal.dist(y_tm1, 1.0).logp(y_t)
    logp.name = 'logp(y_t)'
    return logp

Y_logp, _ = theano.scan(fn=output_step_fn,
                              sequences=[
                                  {"input": Y_obs, "taps": [0, -1]}
                              ],
                              outputs_info=[
                                  {}
                              ])

output_scan_args = ScanArgs.from_node(Y_logp.owner)
Listing 13
print(output_scan_args)
ScanArgs(
  n_steps=Elemwise{minimum,no_inplace}.0,
  outer_in_seqs=[Subtensor{:int64:}.0, Subtensor{:int64:}.0],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[],
  outer_in_sit_sot=[],
  outer_in_shared=[],
  outer_in_nit_sot=[Elemwise{minimum,no_inplace}.0],
  outer_in_non_seqs=[],
  inner_in_seqs=[y_t, y_tm1],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[],
  inner_in_sit_sot=[],
  inner_in_shared=[],
  inner_in_non_seqs=[],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[],
  inner_out_sit_sot=[],
  inner_out_nit_sot=[logp(y_t)],
  inner_out_shared=[],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[],
  outer_out_sit_sot=[],
  outer_out_nit_sot=[for{cpu,scan_fn}.0],
  outer_out_shared=[])

A Variation of Sit-Sot That's Represented as a Mit-Sot

The example in Listing 16 demonstrates how a Scan that seemingly specifies a sit-sot is actually represented as a mit-sot. The defining characteristic is the lag/tap order: by making the tap -3 instead of -1, a seemingly "single-input-tap" output is actually represented as a single "multiple-input-tap" output.

def input_step_fn(y_tm1, rng):
    y_tm1.name = 'y_tm1'
    return NormalRV(y_tm1, 1.0, rng=rng, name='Y_t')

Y_rv, _ = theano.scan(fn=input_step_fn,
                        outputs_info=[
                            {"initial": tt.as_tensor_variable(np.r_[-1.0, 0.0]),
                             "taps": [-3]},
                        ],
                        non_sequences=[rng_tt],
                        n_steps=10)

Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'

input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
Listing 16
print(input_scan_args)
Listing 17
ScanArgs(
  n_steps=TensorConstant{10},
  outer_in_seqs=[],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[IncSubtensor{Set;:int64:}.0],
  outer_in_sit_sot=[],
  outer_in_shared=[rng],
  outer_in_nit_sot=[],
  outer_in_non_seqs=[],
  inner_in_seqs=[],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[[y_tm1]],
  inner_in_sit_sot=[],
  inner_in_shared=[rng_copy],
  inner_in_non_seqs=[],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[[-3]],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[Y_t],
  inner_out_sit_sot=[],
  inner_out_nit_sot=[],
  inner_out_shared=[rng_copy],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[Y_all],
  outer_out_sit_sot=[],
  outer_out_nit_sot=[],
  outer_out_shared=[for{cpu,scan_fn}.1])

Notice, in Listing 17, how our old sit-sot is now suddenly a mit-sot! Clearly, a sit-sot is just a special case of a mit-sot, making the distinction somewhat arbitrary and the implementation a bit redundant.
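
To summarize how the same recursion lands in different ScanArgs fields depending only on the requested taps (cf. Listings 5, 11, and 17):

# taps=[-1]     -> inner_in_sit_sot=[y_tm1],          mit_sot_in_slices=[]
# taps=[-3]     -> inner_in_mit_sot=[[y_tm1]],        mit_sot_in_slices=[[-3]]
# taps=[-1, -2] -> inner_in_mit_sot=[[y_tm1, y_tm2]], mit_sot_in_slices=[[-1, -2]]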

Multi-tap Dependent Arguments

The situation in Listing 19 illustrates a Scan with multiple inter-dependent output taps. In this instance, we need to either include the other outputs in the log-likelihood-computing Scan, or convert both outputs to inputs and make the inner-graph log-likelihood depend on the resulting new inputs.

def input_step_fn(mu_tm1, y_tm1, rng):
    mu_tm1.name = 'mu_tm1'
    y_tm1.name = 'y_tm1'
    mu = mu_tm1 + y_tm1 + 1
    mu.name = 'mu_t'
    return mu, NormalRV(mu, 1.0, rng=rng, name='Y_t')

(mu_tt, Y_rv), _ = theano.scan(fn=input_step_fn,
                               outputs_info=[
                                   {"initial": tt.as_tensor_variable(np.r_[0.0]),
                                    "taps": [-1]},
                                   {"initial": tt.as_tensor_variable(np.r_[0.0]),
                                    "taps": [-1]},
                               ],
                               non_sequences=[rng_tt],
                               n_steps=10)

mu_tt.name = 'mu_tt'
mu_tt.owner.inputs[0].name = 'mu_all'
Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'

input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
Listing 19
print(input_scan_args)
ScanArgs(
  n_steps=TensorConstant{10},
  outer_in_seqs=[],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[],
  outer_in_sit_sot=[IncSubtensor{Set;:int64:}.0, IncSubtensor{Set;:int64:}.0],
  outer_in_shared=[rng],
  outer_in_nit_sot=[],
  outer_in_non_seqs=[],
  inner_in_seqs=[],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[],
  inner_in_sit_sot=[mu_tm1, y_tm1],
  inner_in_shared=[rng_copy],
  inner_in_non_seqs=[],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[],
  inner_out_sit_sot=[mu_t, Y_t],
  inner_out_nit_sot=[],
  inner_out_shared=[rng_copy],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[],
  outer_out_sit_sot=[mu_all, Y_all],
  outer_out_nit_sot=[],
  outer_out_shared=[for{cpu,scan_fn}.2])

Listing 22 provides a log-likelihood-computing Scan for Listing 19.

Y_obs = tt.as_tensor_variable(Y_rv.eval())
Y_obs.name = 'Y_obs'


def output_step_fn(y_t, y_tm1, mu_tm1):
    mu_tm1.name = 'mu_tm1'
    y_tm1.name = 'y_tm1'
    mu = mu_tm1 + y_tm1 + 1
    mu.name = 'mu_t'
    logp = pm.Normal.dist(mu, 1.0).logp(y_t)
    logp.name = 'logp'
    return mu, logp

(mu_tt, Y_logp), _ = theano.scan(fn=output_step_fn,
                                 sequences=[
                                     {"input": Y_obs, "taps": [0, -1]}
                                 ],
                                 outputs_info=[
                                     {"initial": tt.as_tensor_variable(np.r_[0.0]),
                                      "taps": [-1]},
                                     {}
                                 ])

Y_logp.name = 'Y_logp'
mu_tt.name = 'mu_tt'

output_scan_args = ScanArgs.from_node(Y_logp.owner)
Listing 22

Notice that y_tm1, formerly an inner_in_sit_sot term, becomes an element of inner_in_seqs, while mu_tm1 remains a sit-sot.

print(output_scan_args)
ScanArgs(
  n_steps=Elemwise{minimum,no_inplace}.0,
  outer_in_seqs=[Subtensor{:int64:}.0, Subtensor{:int64:}.0],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[],
  outer_in_sit_sot=[IncSubtensor{Set;:int64:}.0],
  outer_in_shared=[],
  outer_in_nit_sot=[Elemwise{minimum,no_inplace}.0],
  outer_in_non_seqs=[],
  inner_in_seqs=[Y_obs[t], y_tm1],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[],
  inner_in_sit_sot=[mu_tm1],
  inner_in_shared=[],
  inner_in_non_seqs=[],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[],
  inner_out_sit_sot=[mu_t],
  inner_out_nit_sot=[logp],
  inner_out_shared=[],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[],
  outer_out_sit_sot=[for{cpu,scan_fn}.0],
  outer_out_nit_sot=[Y_logp],
  outer_out_shared=[])

Examples

Simple HMM

rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt

N_tt = tt.iscalar("N")
N_tt.tag.test_value = 10
M_tt = tt.iscalar("M")
M_tt.tag.test_value = 2

mus_tt = tt.matrix("mus")
mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype(
    theano.config.floatX
)

sigmas_tt = tt.ones((N_tt,))
sigmas_tt.name = "sigmas"

pi_0_rv = DirichletRV(tt.ones((M_tt,)), rng=rng_tt, name="pi_0")
Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")

S_0_rv = CategoricalRV(pi_0_rv, rng=rng_tt, name="S_0")

def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t, rng):
    gamma_t = Gamma_t[S_tm1]
    gamma_t.name = "gamma_t"
    S_t = CategoricalRV(gamma_t, rng=rng, name="S_t")
    mu_t = mus_t[S_t]
    mu_t.name = "mu[S_t]"
    Y_t = NormalRV(mu_t, sigma_t, rng=rng, name="Y_t")
    return S_t, Y_t

(S_rv, Y_rv), scan_updates = theano.scan(
    fn=scan_fn,
    sequences=[mus_tt, sigmas_tt],
    non_sequences=[Gamma_rv, rng_tt],
    outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {}],
    strict=True,
    name="scan_rv",
)
Y_rv.name = "Y_rv"

input_scan_args = ScanArgs.from_node(Y_rv.owner)
print(input_scan_args)
ScanArgs(
  n_steps=Elemwise{minimum,no_inplace}.0,
  outer_in_seqs=[Subtensor{:int64:}.0, Subtensor{:int64:}.0],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[],
  outer_in_sit_sot=[IncSubtensor{Set;:int64:}.0],
  outer_in_shared=[rng],
  outer_in_nit_sot=[Elemwise{minimum,no_inplace}.0],
  outer_in_non_seqs=[Gamma],
  inner_in_seqs=[mus[t], sigmas[t]],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[],
  inner_in_sit_sot=[S_0[t-1]],
  inner_in_shared=[rng_copy],
  inner_in_non_seqs=[Gamma_copy],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[],
  inner_out_sit_sot=[S_t],
  inner_out_nit_sot=[Y_t],
  inner_out_shared=[rng_copy],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[],
  outer_out_sit_sot=[for{cpu,scan_rv}.0],
  outer_out_nit_sot=[Y_rv],
  outer_out_shared=[for{cpu,scan_rv}.2])
test_point = {
    M_tt: 2,
    N_tt: 10,
    mus_tt: mus_tt.tag.test_value,
}
Y_obs = tt.as_tensor_variable(Y_rv.eval(test_point))
Y_obs.name = 'Y_obs'


def logp_scan_fn(y_t, mus_t, sigma_t, S_tm1, Gamma_t, rng):
    gamma_t = Gamma_t[S_tm1]
    gamma_t.name = "gamma_t"
    S_t = CategoricalRV(gamma_t, rng=rng, name="S_t")
    mu_t = mus_t[S_t]
    mu_t.name = "mu[S_t]"
    Y_logp_t = pm.Normal.dist(mu_t, sigma_t).logp(y_t)
    Y_logp_t.name = "logp(y_t)"
    return S_t, Y_logp_t

(S_rv, Y_logp), scan_updates = theano.scan(
    fn=logp_scan_fn,
    sequences=[Y_obs, mus_tt, sigmas_tt],
    non_sequences=[Gamma_rv, rng_tt],
    outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {}],
    strict=True,
    name="scan_rv",
)
Y_logp.name = "Y_logp"

output_scan_args = ScanArgs.from_node(Y_logp.owner)
print(output_scan_args)
ScanArgs(
  n_steps=Elemwise{minimum,no_inplace}.0,
  outer_in_seqs=[Subtensor{:int64:}.0, Subtensor{:int64:}.0, Subtensor{:int64:}.0],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[],
  outer_in_sit_sot=[IncSubtensor{Set;:int64:}.0],
  outer_in_shared=[rng],
  outer_in_nit_sot=[Elemwise{minimum,no_inplace}.0],
  outer_in_non_seqs=[Gamma],
  inner_in_seqs=[Y_obs[t], mus[t], sigmas[t]],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[],
  inner_in_sit_sot=[S_0[t-1]],
  inner_in_shared=[rng_copy],
  inner_in_non_seqs=[Gamma_copy],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[],
  inner_out_sit_sot=[S_t],
  inner_out_nit_sot=[logp(y_t)],
  inner_out_shared=[rng_copy],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[],
  outer_out_sit_sot=[for{cpu,scan_rv}.0],
  outer_out_nit_sot=[Y_logp],
  outer_out_shared=[for{cpu,scan_rv}.2])

Implementation

In Listing 31 we create a function that converts an outer-graph output of a ScanArgs into an outer-graph input.

def convert_outer_out_to_in(input_scan_args, var, inner_out_fn=None, output_scan_args=None):
    """Convert outer-graph outputs into outer-graph inputs.

    Parameters
    ----------
    input_scan_args: ScanArgs
        The source scan arguments.
    var: TensorVariable
        The outer-graph output variable that is to be converted into an
        outer-graph input.
    inner_out_fn: function (Optional)
        A function with the signature `(input_scan_args, old_inner_out_var,
        new_outer_input_var, output_scan_args)` that produces a new inner-graph
        output.  This can be used to transform `var`'s corresponding
        inner-graph output, for example.
    output_scan_args: ScanArgs (Optional)
        If this argument is non-`None`, the conversion is applied to the given
        `ScanArgs` and the old `var` output is removed.

    Returns
    -------
    (ScanArgs, new_outer_input_var)
        A `tuple` containing a `ScanArgs` object with the outer-graph output
        given by `var` converted to an outer-graph input, and a variable
        representing the new outer-graph input.

    Additionally, some meta information is attached to the output `ScanArgs`
    object that can be used to remove initial values from the outer-graph
    outputs.

    """
    from copy import copy

    replacing = False
    if output_scan_args is None:
        output_scan_args = ScanArgs.create_empty()
    elif output_scan_args == input_scan_args:
        replacing = True
        # We will not change the input `ScanArgs` in-place
        if output_scan_args is input_scan_args:
            output_scan_args = copy(input_scan_args)

    var_info = input_scan_args.find_among_fields(var, field_filter=lambda x: x.startswith("outer_out"))

    old_inner_out_var = input_scan_args.get_alt_field(var_info, "inner_out")

    if replacing:
        # Remove the old outer-output variable
        # Not sure if this really matters, since we don't use the outer-outputs
        # when building a new `Scan`, but doing it keeps the `ScanArgs` object
        # consistent.
        output_scan_args.remove_from_fields(var, rm_dependents=False)
        output_scan_args.remove_from_fields(old_inner_out_var, rm_dependents=False)

    # Couldn't one do the same with `var_info`?
    inner_out_info = input_scan_args.find_among_fields(old_inner_out_var,
                                                       field_filter=lambda x: x.startswith("inner_out"))

    # Use the index for the specific inner-graph sub-collection to which this
    # variable belongs (e.g. index `1` among the inner-graph sit-sot terms)
    var_idx = inner_out_info.index

    # The old inner-output variable becomes a new inner-input
    inner_in_var = old_inner_out_var.clone()

    # We need to clone any existing inner-output variables in the `ScanArgs`
    # object that we're mutating and replace references to `old_inner_out_var`
    # with `inner_in_var`.  If we don't, then any other inner-outputs that
    # reference the inner-output that we're replacing will be inconsistent.
    # Instead, we want those other inner-outputs to reference the new
    # inner-input replacement variable.
    from theano.scan_module.scan_utils import clone as tt_clone
    for io_var in list(output_scan_args.inner_outputs):
        io_var_info = output_scan_args.find_among_fields(io_var, field_filter=lambda x: x.startswith("inner_out"))
        io_sub_list = getattr(output_scan_args, io_var_info.name)

        new_io_var, = tt_clone([io_var], replace={old_inner_out_var: inner_in_var})

        io_sub_list[io_var_info.index] = new_io_var

    inner_in_seqs = [inner_in_var]
    if inner_out_info.name.endswith('mit_sot'):
        inner_in_seqs = input_scan_args.inner_in_mit_sot[var_idx] + inner_in_seqs
        if replacing:
            output_scan_args.inner_in_mit_sot.pop(var_idx)
            output_scan_args.outer_in_mit_sot.pop(var_idx)
    elif inner_out_info.name.endswith('sit_sot'):
        inner_in_seqs = [input_scan_args.inner_in_sit_sot[var_idx]] + inner_in_seqs
        if replacing:
            output_scan_args.inner_in_sit_sot.pop(var_idx)
            output_scan_args.outer_in_sit_sot.pop(var_idx)

    taps = [0]
    if inner_out_info.name.endswith('mit_sot'):
        taps = input_scan_args.mit_sot_in_slices[var_idx] + taps
        if replacing:
            output_scan_args.mit_sot_in_slices.pop(var_idx)
    elif inner_out_info.name.endswith('sit_sot'):
        taps = [-1] + taps

    taps, inner_in_seqs = zip(*sorted(zip(taps, inner_in_seqs), key=lambda x: x[0]))

    inner_in_seqs = list(reversed(inner_in_seqs))
    output_scan_args.inner_in_seqs += inner_in_seqs

    taps = np.asarray(taps)
    slice_seqs = zip(-taps, [n if n < 0 else None for n in reversed(taps)])
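    # E.g. for sorted taps (-2, -1, 0) this yields the slices [2:], [1:-1],
    # and [:-2]: the tap-aligned views of the new outer-input, ordered to
    # match the reversed `inner_in_seqs` above (tap 0 first).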

    # This variable is our new outer-input.
    new_input_var = var.clone()

    var_slices = [new_input_var[b:e] for b, e in slice_seqs]
    n_steps = tt.min([tt.shape(n)[0] for n in var_slices])
    # n_steps -= np.abs(start_idx) + np.abs(stop_idx)

    if output_scan_args.n_steps is None or replacing:
        output_scan_args.n_steps = n_steps

    # output_scan_args.input_taps = taps
    output_scan_args.outer_in_seqs += [v[:n_steps] for v in var_slices]
    output_scan_args.outer_in_nit_sot += [n_steps]

    if inner_out_fn:
        output_scan_args.inner_out_nit_sot += [
            inner_out_fn(input_scan_args, old_inner_out_var, new_input_var, output_scan_args)
        ]

    return (output_scan_args, new_input_var)
Listing 31
from theano.scan_module.scan_op import Scan
from symbolic_pymc.theano.ops import RandomVariable


def get_random_outer_outputs(scan_args):
    """Get the `RandomVariable` outputs of a `Scan` (well, it's `ScanArgs`)."""
    rv_vars = []
    for n, oo in enumerate(scan_args.outer_outputs):
        oo_info = scan_args.find_among_fields(oo)
        io_type = oo_info.name[(oo_info.name.index('_', 6) + 1):]
        inner_out_type = "inner_out_{}".format(io_type)
        io_var = getattr(scan_args, inner_out_type)[oo_info.index]
        if io_var.owner and isinstance(io_var.owner.op, RandomVariable):
            rv_vars.append((n, oo))
    return rv_vars


def inner_out_fn(input_scan_args, old_inner_out_var, new_outer_input_var, output_scan_args):
    from symbolic_pymc.theano.pymc3 import _logp_fn
    Y_t = old_inner_out_var
    y_t = output_scan_args.inner_in_seqs[0]
    logp = _logp_fn(Y_t.owner.op, Y_t, None)(y_t)
    logp.name = "logp(y_t)"
    return logp

Mit-Sot Test

def input_step_fn(y_tm1, y_tm2, rng):
    y_tm1.name = 'y_tm1'
    y_tm2.name = 'y_tm2'
    return NormalRV(y_tm1 + y_tm2, 1.0, rng=rng, name='Y_t')


Y_rv, _ = theano.scan(fn=input_step_fn,
                      outputs_info=[
                          {"initial": tt.as_tensor_variable(np.r_[-1.0, 0.0]),
                           "taps": [-1, -2]},
                      ],
                      non_sequences=[rng_tt],
                      n_steps=10)

Y_rv.name = 'Y_rv'
Y_rv.owner.inputs[0].name = 'Y_all'

input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
print(input_scan_args)

# Pick the first one for testing
var_idx, var = get_random_outer_outputs(input_scan_args)[0]

test_scan_args, test_Y_rv = convert_outer_out_to_in(input_scan_args, var,
                                                    inner_out_fn=inner_out_fn,
                                                    output_scan_args=input_scan_args)
# `test_Y_rv` is the symbolic outer-input variable for our new log-likelihood graph
test_scan_args
output_scan_args

test_scan_op = Scan(test_scan_args.inner_inputs,
                    test_scan_args.inner_outputs,
                    test_scan_args.info)
scan_out = test_scan_op(*test_scan_args.outer_inputs)

if not isinstance(scan_out, list):
    test_scan_op = scan_out.owner
    scan_out = [scan_out]
else:
    test_scan_op = scan_out[0].owner

# The actual input of this should be `Y_obs`
Y_obs = tt.as_tensor_variable(Y_rv.eval())
Y_obs.name = 'Y_obs'

def output_step_fn(y_t, y_tm1, y_tm2):
    y_t.name = 'y_t'
    y_tm1.name = 'y_tm1'
    y_tm2.name = 'y_tm2'
    logp = pm.Normal.dist(y_tm1 + y_tm2, 1.0).logp(y_t)
    logp.name = 'logp(y_t)'
    return logp

Y_logp, _ = theano.scan(fn=output_step_fn,
                        sequences=[
                            {"input": Y_obs, "taps": [0, -1, -2]}
                        ],
                        outputs_info=[
                            {}
                        ])

output_scan_args = ScanArgs.from_node(Y_logp.owner)
print(output_scan_args)

# tt_dprint(scan_out)
# Y_rv
# tt_dprint(Y_rv)

rng_tt.get_value(borrow=True).set_state(rng_init_state)
res = scan_out[var_idx].eval({test_Y_rv: Y_obs.value})

rng_tt.get_value(borrow=True).set_state(rng_init_state)
exp_res = Y_logp.eval()

assert np.array_equal(res, exp_res)
ScanArgs(
  n_steps=Elemwise{minimum,no_inplace}.0,
  outer_in_seqs=[Subtensor{:int64:}.0, Subtensor{:int64:}.0, Subtensor{:int64:}.0],
  outer_in_mit_mot=[],
  outer_in_mit_sot=[],
  outer_in_sit_sot=[],
  outer_in_shared=[],
  outer_in_nit_sot=[Elemwise{minimum,no_inplace}.0],
  outer_in_non_seqs=[],
  inner_in_seqs=[y_t, y_tm1, y_tm2],
  inner_in_mit_mot=[],
  inner_in_mit_sot=[],
  inner_in_sit_sot=[],
  inner_in_shared=[],
  inner_in_non_seqs=[],
  mit_mot_in_slices=[],
  mit_sot_in_slices=[],
  inner_out_mit_mot=[],
  inner_out_mit_sot=[],
  inner_out_sit_sot=[],
  inner_out_nit_sot=[logp(y_t)],
  inner_out_shared=[],
  mit_mot_out_slices=[],
  outer_out_mit_mot=[],
  outer_out_mit_sot=[],
  outer_out_sit_sot=[],
  outer_out_nit_sot=[for{cpu,scan_fn}.0],
  outer_out_shared=[])

Full Model Conversion

import numpy as np
import theano
import theano.tensor as tt

import pymc3 as pm

from symbolic_pymc.theano.random_variables import (
    NormalRV,
    MvNormalRV,
    DirichletRV,
    CategoricalRV,
    Observed,
    observed,
)

from symbolic_pymc.theano.opt import ScanArgs
from theano.printing import debugprint as tt_dprint


theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
tt.config.compute_test_value = "warn"  # 'ignore'

rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_init_state = rng_state.get_state()
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt
from theano.scan_module.scan_op import Scan
from symbolic_pymc.theano.ops import RandomVariable


def get_random_outer_outputs(scan_args):
    """Get the `RandomVariable` outputs of a `Scan` (well, it's `ScanArgs`)."""
    rv_vars = []
    for n, oo in enumerate(scan_args.outer_outputs):
        oo_info = scan_args.find_among_fields(oo)
        io_type = oo_info.name[(oo_info.name.index('_', 6) + 1):]
        inner_out_type = "inner_out_{}".format(io_type)
        io_var = getattr(scan_args, inner_out_type)[oo_info.index]
        if io_var.owner and isinstance(io_var.owner.op, RandomVariable):
            rv_vars.append((n, oo))
    return rv_vars


def inner_out_fn(input_scan_args, old_inner_out_var, new_outer_input_var, output_scan_args):
    from symbolic_pymc.theano.pymc3 import _logp_fn
    Y_t = old_inner_out_var
    y_t = output_scan_args.inner_in_seqs[0]
    logp = _logp_fn(Y_t.owner.op, Y_t, None)(y_t)
    logp.name = "logp(y_t)"
    return logp


theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
tt.config.compute_test_value = "warn"  # 'ignore'

rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_init_state = rng_state.get_state()
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt

N_tt = tt.iscalar("N")
N_tt.tag.test_value = 10
M_tt = tt.iscalar("M")
M_tt.tag.test_value = 2

mus_tt = tt.matrix("mus_t")
mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype(
    theano.config.floatX
)

sigmas_tt = tt.ones((N_tt,))
sigmas_tt.name = "sigmas"

pi_0_rv = DirichletRV(tt.ones((M_tt,)), rng=rng_tt, name="pi_0")
Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")
S_0_rv = CategoricalRV(pi_0_rv, rng=rng_tt, name="S_0")

def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t, rng):
    S_t = CategoricalRV(Gamma_t[0], rng=rng, name="S_t")
    Y_t = NormalRV(mus_t[S_t], sigma_t, rng=rng, name="Y_t")
    return S_t, Y_t

(S_rv, Y_rv), _ = theano.scan(
    fn=scan_fn,
    sequences=[mus_tt, sigmas_tt],
    non_sequences=[Gamma_rv, rng_tt],
    outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {},],
    strict=True,
    name="scan_rv",
)
S_rv.name = "S_rv"
Y_rv.name = "Y_rv"
from theano.gof import Query
from theano.compile import optdb
from theano.gof.opt import SeqOptimizer, EquilibriumOptimizer
from theano.gof.graph import inputs as tt_inputs, io_toposort as tt_io_toposort

from symbolic_pymc.theano.opt import push_out_rvs_from_scan, convert_outer_out_to_in, FunctionGraph, optimize_graph
from symbolic_pymc.theano.pymc3 import _logp_fn


output_vars = (Y_rv,)


def get_random_outer_outputs(scan_args):
    """Get the `RandomVariable` outputs of a `Scan` (well, it's `ScanArgs`)."""
    rv_vars = []
    for n, oo in enumerate(scan_args.outer_outputs):
        oo_info = scan_args.find_among_fields(oo)
        io_type = oo_info.name[(oo_info.name.index("_", 6) + 1) :]
        inner_out_type = "inner_out_{}".format(io_type)
        io_var = getattr(scan_args, inner_out_type)[oo_info.index]
        if io_var.owner and isinstance(io_var.owner.op, RandomVariable):
            rv_vars.append((n, oo))
    return rv_vars


def create_inner_out_logp(input_scan_args, old_inner_out_var, new_inner_in_var, output_scan_args):
    """Create a log-likelihood inner-output.

    This is intended to be used with `get_random_outer_outputs`.

    """
    logp_fn = _logp_fn(old_inner_out_var.owner.op, old_inner_out_var.owner, None)
    logp = logp_fn(new_inner_in_var)
    if new_inner_in_var.name:
        logp.name = "logp({})".format(new_inner_in_var.name)
    return logp


def construct_scan(scan_args):
    scan_op = Scan(scan_args.inner_inputs, scan_args.inner_outputs, scan_args.info)
    scan_out = scan_op(*scan_args.outer_inputs)

    if not isinstance(scan_out, list):
        scan_out = [scan_out]

    return scan_out


def logp(*output_vars):
    """Compute the log-likelihood for a graph.

    Parameters
    ----------
    *output_vars: Tuple[TensorVariable]
        The outputs of a graph containing `RandomVariable`s.

    Returns
    -------
    Dict[TensorVariable, TensorVariable]
        A map from `RandomVariable`s to their log-likelihood graphs.

    """
    model_fgraph = FunctionGraph(tt_inputs(output_vars), output_vars, clone=True)

    canonicalize_opt = optdb.query(Query(include=["canonicalize"]))

    optimizations = SeqOptimizer(canonicalize_opt.copy())
    optimizations.append(EquilibriumOptimizer([push_out_rvs_from_scan], max_use_ratio=10))

    opt_fgraph = optimize_graph(model_fgraph, optimizations, in_place=True)

    rv_to_logp_io = {}
    for node in opt_fgraph.toposort():
        if isinstance(node.op, RandomVariable):
            var = node.default_output()
            rv_to_logp_io[var] = _logp_fn(node.op, var.owner, None)(var)
        elif isinstance(node.op, Scan):
            scan_args = ScanArgs.from_node(node)
            rv_outer_outs = get_random_outer_outputs(scan_args)

            for var_idx, var in rv_outer_outs:
                scan_args = convert_outer_out_to_in(
                    scan_args, var,
                    inner_out_fn=create_inner_out_logp,
                    output_scan_args=scan_args
                )

            logp_scan_out = construct_scan(scan_args)

            for var_idx, var in rv_outer_outs:
                rv_to_logp_io[var] = logp_scan_out[var_idx]

    return rv_to_logp_io
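
A minimal usage sketch, assuming the HMM model above is in scope; note that logp keys its result by variables in the optimized graph, which may be clones of the original outputs:

rv_to_logp_io = logp(Y_rv)
for rv, rv_logp in rv_to_logp_io.items():
    print(rv, "->", rv_logp)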
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment