Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Created June 7, 2020 02:55
Show Gist options
  • Save brandonwillard/4fdeedd3d6eafb2b977a453149b6c9d1 to your computer and use it in GitHub Desktop.
Automatic PyMC3 to RandomVariable conversion prototype
import theano
import theano.tensor as tt
import pymc3 as pm
from warnings import warn
from unittest.mock import patch
from types import SimpleNamespace
from inspect import Signature
from collections import OrderedDict
from symbolic_pymc.theano.ops import RandomVariable
def is_tensor_compatible(x):
    """Return ``True`` when `x` can be converted to a Theano tensor variable."""
    try:
        tt.as_tensor_variable(x)
    except Exception:
        return False
    else:
        return True
class Empty(object):
    """Bare placeholder object.

    Instances get their ``__class__`` reassigned at runtime so that unbound
    `Distribution` methods can be called against a fake ``self``.
    """
def pymc3_dist_to_rv(pm_class, dtype=None, ndim_supp=None, ndims_params=None):
    """Create a `RandomVariable` Op instance for a given `pymc3.Distribution` class.

    Parameters
    ----------
    pm_class: Distribution
        The PyMC3 `Distribution` class.
    dtype: Theano dtype, optional
        The underlying dtype.  Falls back to the class's own ``dtype``
        attribute, then ``"int64"`` for discrete distributions, then
        ``theano.config.floatX``.
    ndim_supp: int, optional
        Dimension of the support. This value is used to infer the exact
        shape of the support and independent terms from ``dist_params``.
        Falls back to the class's own ``ndim_supp`` attribute, then 0.
    ndims_params: tuple (int), optional
        Number of dimensions of each parameter in ``dist_params``.
        Falls back to the class's own ``ndims_params`` attribute, then
        all-scalar parameters inferred from the ``__init__`` signature.

    Returns
    -------
    An instance of the newly created `RandomVariable` subclass.
    """
    assert issubclass(pm_class, pm.Distribution)

    name = f'{pm_class.__name__}RVType'
    init_sig = Signature.from_callable(pm_class.__init__)

    # Prefer a class-level dtype, then the caller's argument, then fall-backs.
    dtype = getattr(pm_class, 'dtype', dtype)
    if not dtype and issubclass(pm_class, pm.Discrete):
        dtype = "int64"
    elif not dtype:
        warn('No dtype specified; assuming floatX.')
        dtype = theano.config.floatX

    # FIX: previously `getattr(pm_class, 'ndim_supp', None)` silently
    # discarded the `ndim_supp` argument; use it as the fall-back default
    # (mirroring the dtype handling above).
    ndim_supp = getattr(pm_class, 'ndim_supp', ndim_supp)
    if not ndim_supp:
        warn('No ndim_supp specified; assuming scalar support.')
        ndim_supp = 0

    # Candidate distribution parameters from the constructor signature
    # (everything that isn't bookkeeping).
    potential_dist_params = [
        p for n, p in init_sig.parameters.items()
        if n not in ('self', 'kwargs', 'args', 'shape', 'size')
    ]

    # FIX: this guard previously tested `ndim_supp` (copy-paste error) and
    # discarded the `ndims_params` argument via a `None` getattr default.
    ndims_params = getattr(pm_class, 'ndims_params', ndims_params)
    if not ndims_params:
        warn('No ndims_params specified; assuming scalar parameters from signature.')
        ndims_params = (0,) * len(potential_dist_params)

    def make_node(self, *args, size=None, rng=None, name=None, **kwargs):
        # Build a stand-in "self" and run the PyMC3 constructor on it so we
        # can recover the tensor-valued distribution parameters it sets.
        pm_self = Empty()
        pm_self.__class__ = pm_class
        # with patch('pymc3.distributions.distribution.Distribution.__init__', return_value=None):
        _ = pm_class.__init__(pm_self, *args, **kwargs)

        # Keep only the attributes that can become tensors, in a stable order.
        pm_self_ordered = OrderedDict(sorted(
            i for i in pm_self.__dict__.items()
            if is_tensor_compatible(i[1])))

        # This will help us remap to a `self`-like namespace when we call
        # other borrowed `Distribution` methods.
        # TODO: We should probably check for new entries every time, no?
        if not hasattr(self, 'dist_args_map'):
            self.dist_args_map = OrderedDict(
                [(k, i) for i, k in enumerate(pm_self_ordered.keys())])

        dist_args = tuple(pm_self_ordered.values())

        return super(type(self), self).make_node(
            *dist_args, size=size, rng=rng, name=name, **kwargs)

    def perform(self, node, inputs, outputs):
        rng_out, smpl_out = outputs

        args = list(inputs)
        # TODO: Temporarily set NumPy's global seed while calling `random`?
        # NOTE(review): this assumes the base `RandomVariable` orders inputs
        # as (*dist_params, size, rng) — confirm against `symbolic_pymc`.
        rng = args.pop()
        size = args.pop()

        # Our emulated `self` will have the "sampled" input values, and we'll
        # make `draw_values` simply return those concrete values.
        pm_self = Empty()
        pm_self.__class__ = pm_class
        pm_self.__dict__.update(dict(zip(self.dist_args_map.keys(), args)))

        def rv_draw_values(*args, **kwargs):
            # Short-circuit PyMC3's sampling machinery: the values are
            # already concrete.
            return args

        with patch('pymc3.distributions.distribution.draw_values',
                   side_effect=rv_draw_values):
            smpl_val = pm_class.random(pm_self, size=size)

        smpl_out[0] = smpl_val

    def __init__(self):
        super(type(self), self).__init__(
            pm_class.__name__, dtype, ndim_supp, ndims_params, None,
            inplace=True)

    def logp(self):
        # TODO: Add an option to produce total log-likelihood by adding
        # log-likelihoods in `dist_params`.
        out_var = self.owner.outputs[1]

        pm_self = Empty()
        pm_self.__class__ = pm_class
        # FIX: previously read `self.dist_args_order`, which is never set
        # anywhere — `make_node` stores the mapping as `dist_args_map`.
        pm_self.__dict__.update(
            dict(zip(self.dist_args_map.keys(), out_var.owner.inputs)))

        return pm_class.logp(pm_self, out_var)

    # Borrow the distribution's own attributes, but keep our overrides.
    clsdict = {'logp': logp, 'perform': perform,
               '__init__': __init__, 'make_node': make_node}
    clsdict.update({k: v for k, v in pm_class.__dict__.items()
                    if k not in ('__module__', 'logp', 'random', '__init__')})

    new_rv_type = type(name, (RandomVariable,), clsdict)

    # TODO: Add this type to a conversion dict.
    NewType = new_rv_type()

    return NewType
# Demo: convert PyMC3's Normal into a RandomVariable Op and draw a sample.
NormalRV = pymc3_dist_to_rv(pm.Normal)
# NOTE(review): positional args are forwarded to `pm.Normal.__init__`
# (mu=-10, sd/sigma=1, depending on the PyMC3 version) — confirm.
normal_rv = NormalRV(-10, 1, size=3)
normal_rv.eval()
# array([-10.093141 , -9.57534957, -10.64113658])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment