# Automatic PyMC3 to RandomVariable conversion prototype.
#
# Originally published as GitHub gist
# brandonwillard/4fdeedd3d6eafb2b977a453149b6c9d1 (created June 7, 2020 02:55).
#
# NOTE: GitHub reported that this file contains hidden or bidirectional
# Unicode text that may be interpreted or compiled differently than what
# appears; review it in an editor that reveals hidden Unicode characters.
import theano | |
import theano.tensor as tt | |
import pymc3 as pm | |
from warnings import warn | |
from unittest.mock import patch | |
from types import SimpleNamespace | |
from inspect import Signature | |
from collections import OrderedDict | |
from symbolic_pymc.theano.ops import RandomVariable | |
def is_tensor_compatible(x):
    """Report whether ``x`` can be turned into a Theano tensor variable.

    The check simply attempts ``tt.as_tensor_variable(x)``; any ``Exception``
    raised by that conversion makes the answer ``False``.
    """
    try:
        tt.as_tensor_variable(x)
    except Exception:
        return False
    else:
        return True
class Empty:
    """Featureless placeholder object.

    ``pymc3_dist_to_rv`` creates instances of this class and then reassigns
    their ``__class__`` to a PyMC3 `Distribution` subclass. The result is an
    object of that class whose ``__init__`` has not run. Keep this a plain
    class with an instance ``__dict__`` (no ``__slots__``) so that the
    ``__class__`` reassignment stays layout-compatible.
    """
def pymc3_dist_to_rv(pm_class, dtype=None, ndim_supp=None, ndims_params=None):
    """Create a `RandomVariable` Op for a given `pymc3.Distribution` class.

    A new `RandomVariable` subclass named ``f"{pm_class.__name__}RVType"`` is
    built and an instance of it is returned. Applying that Op (Theano's
    ``Op.__call__`` invokes ``make_node``) forwards its positional and keyword
    arguments to ``pm_class.__init__``. Every tensor-compatible attribute that
    the constructor sets becomes an Op input. Sampling (``perform``) delegates
    to ``pm_class.random`` and ``logp`` delegates to ``pm_class.logp``.

    Parameters
    ----------
    pm_class: Distribution
        The PyMC3 `Distribution` class.
    dtype: Theano dtype, optional
        The underlying dtype. A ``dtype`` attribute on ``pm_class`` takes
        precedence. If neither is given, ``"int64"`` is used for
        `pymc3.Discrete` subclasses and ``theano.config.floatX`` otherwise.
    ndim_supp: int, optional
        Dimension of the support. This value is used to infer the exact
        shape of the support and independent terms from ``dist_params``. A
        ``ndim_supp`` attribute on ``pm_class`` takes precedence. Defaults to
        ``0`` (scalar support) with a warning.
    ndims_params: tuple (int), optional
        Number of dimensions of each parameter in ``dist_params``. A
        ``ndims_params`` attribute on ``pm_class`` takes precedence. Defaults
        to ``0`` for each parameter in ``pm_class.__init__``'s signature
        (excluding ``self``, ``args``, ``kwargs``, ``shape`` and ``size``),
        with a warning.

    Returns
    -------
    RandomVariable
        An instance of the generated Op type.

    Raises
    ------
    TypeError
        If ``pm_class`` is not a subclass of `pymc3.Distribution`.
    """
    if not (isinstance(pm_class, type) and issubclass(pm_class, pm.Distribution)):
        raise TypeError(
            f'Expected a `pymc3.Distribution` subclass, got {pm_class!r}')

    type_name = f'{pm_class.__name__}RVType'

    # Class-level attributes take precedence over the explicit arguments.
    dtype = getattr(pm_class, 'dtype', dtype)
    if not dtype and issubclass(pm_class, pm.Discrete):
        dtype = "int64"
    elif not dtype:
        warn('No dtype specified; assuming floatX.')
        dtype = theano.config.floatX

    ndim_supp = getattr(pm_class, 'ndim_supp', ndim_supp)
    # Compare with ``None``: 0 (scalar support) is a legitimate value.
    if ndim_supp is None:
        warn('No ndim_supp specified; assuming scalar support.')
        ndim_supp = 0

    ndims_params = getattr(pm_class, 'ndims_params', ndims_params)
    if ndims_params is None:
        warn('No ndims_params specified; assuming scalar parameters from signature.')
        init_sig = Signature.from_callable(pm_class.__init__)
        potential_dist_params = [
            n for n in init_sig.parameters
            if n not in ('self', 'kwargs', 'args', 'shape', 'size')]
        ndims_params = (0,) * len(potential_dist_params)

    def new_pm_instance(attrs=()):
        """Return a ``pm_class`` object (``__init__`` not run) with ``attrs`` set."""
        pm_self = Empty()
        pm_self.__class__ = pm_class
        pm_self.__dict__.update(attrs)
        return pm_self

    def make_node(self, *args, size=None, rng=None, name=None, **kwargs):
        # Let the PyMC3 constructor compute its parameters on a stand-in
        # object. ``pm_class(...)`` is never called, so ``pm_class.__new__``
        # is bypassed.
        pm_self = new_pm_instance()
        pm_class.__init__(pm_self, *args, **kwargs)

        # Every tensor-compatible attribute becomes an Op input. Sorting by
        # attribute name gives a deterministic input order.
        pm_self_ordered = OrderedDict(sorted(
            ((k, v) for k, v in pm_self.__dict__.items() if is_tensor_compatible(v)),
            key=lambda kv: kv[0]))

        # This map lets `perform`/`logp` rebuild a `self`-like object from
        # the node inputs when calling the borrowed `Distribution` methods.
        # TODO: The map is fixed by the first call; a later call whose
        # constructor sets a different attribute set is not detected.
        if not hasattr(self, 'dist_args_map'):
            self.dist_args_map = OrderedDict(
                (k, i) for i, k in enumerate(pm_self_ordered))

        # NOTE(review): the number of inputs built here is the number of
        # tensor-compatible attributes, which need not equal
        # ``len(ndims_params)``. Confirm that `RandomVariable` tolerates this.
        dist_args = tuple(pm_self_ordered.values())

        # ``kwargs`` were consumed by ``pm_class.__init__`` above. They are
        # distribution parameters, not `RandomVariable.make_node` options, so
        # they are not forwarded.
        # ``super(new_rv_type, ...)`` rather than ``super(type(self), ...)``
        # avoids infinite recursion if the generated type is subclassed.
        return super(new_rv_type, self).make_node(
            *dist_args, size=size, rng=rng, name=name)

    def perform(self, node, inputs, outputs):
        rng_out, smpl_out = outputs
        args = list(inputs)
        # The last two inputs are ``size`` and ``rng``; the rest are the
        # distribution arguments recorded in ``dist_args_map``.
        # TODO: Temporarily set NumPy's global seed while calling `random`?
        rng = args.pop()
        size = args.pop()

        # Our emulated `self` carries the concrete input values, and
        # `draw_values` is made to return those values unchanged.
        pm_self = new_pm_instance(zip(self.dist_args_map, args))

        def rv_draw_values(params, *_args, **_kwargs):
            return list(params)

        # NOTE(review): if ``pm_class``'s module imported ``draw_values`` by
        # name, this patch does not replace that reference. Confirm the
        # patch target.
        with patch('pymc3.distributions.distribution.draw_values',
                   side_effect=rv_draw_values):
            smpl_out[0] = pm_class.random(pm_self, size=size)

        # ``rng`` is not consumed by the draw above, so pass it through.
        rng_out[0] = rng

    def __init__(self):
        # The positional ``None`` fills the slot after ``ndims_params``
        # (presumably the sampling function, unused here because `perform`
        # is overridden). TODO confirm against `RandomVariable.__init__`.
        super(new_rv_type, self).__init__(
            pm_class.__name__, dtype, ndim_supp, ndims_params, None, inplace=True)

    def logp(self, rv_var):
        """Return ``pm_class.logp`` evaluated at ``rv_var``.

        ``rv_var`` must be the sample output of an application of this Op.
        """
        # TODO: Add an option to produce total log-likelihood by adding
        # log-likelihoods in `dist_params`.
        # ``zip`` stops at the shorter sequence, so the trailing
        # ``size``/``rng`` inputs are ignored.
        pm_self = new_pm_instance(zip(self.dist_args_map, rv_var.owner.inputs))
        return pm_class.logp(pm_self, rv_var)

    # Copy the distribution's own attributes first, then install the
    # generated methods so that copied entries can never override them.
    # ``__dict__``/``__weakref__`` descriptors must not be copied onto a type
    # with a different base.
    clsdict = {k: v for k, v in pm_class.__dict__.items()
               if k not in ('__module__', 'logp', 'random', '__init__',
                            '__dict__', '__weakref__')}
    clsdict.update(logp=logp, perform=perform, __init__=__init__, make_node=make_node)

    new_rv_type = type(type_name, (RandomVariable,), clsdict)

    # TODO: Add this type to a conversion dict.
    return new_rv_type()
# Demo: build the `RandomVariable` Op for `pm.Normal` and apply it to the
# positional arguments (-10, 1). These are forwarded to `pm.Normal.__init__`,
# presumably as mu and sigma. ``size=3`` requests three draws, and the
# resulting sample is then evaluated.
NormalRV = pymc3_dist_to_rv(pm.Normal)
normal_rv = NormalRV(-10, 1, size=3)
normal_rv.eval()
# Example result (random draw; actual values vary):
# array([-10.093141 , -9.57534957, -10.64113658])
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.