Last active April 13, 2021 20:27
twiecki/a77104299535b64b58953de3c84df56f
stochastic_volatility.ipynb
I suppose pm.sample() already does this?
This looks like something we need to update in PyMC3, as well.
Here's a quick comparison of the timing with graph optimizations (first) and without them (second); the example model is taken from this notebook:
# With graph optimizations: reuse the FunctionGraph that aesara.function already optimized
fgraph = model.logp.f.maker.fgraph
...
get_ipython().run_line_magic('timeit', 'fn2(*init_state).block_until_ready()')
# 198 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
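# Without graph optimizations: a fresh FunctionGraph built directly from the model's log-likelihood
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])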
...
get_ipython().run_line_magic('timeit', 'fn2(*init_state).block_until_ready()')
# 236 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
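As a rough reading of these numbers, taking the reported means at face value, the unoptimized graph costs about 19% more per call. Given the reported spreads (± 18.3 µs and ± 10.7 µs over 7 runs), this is an approximate figure from a single benchmark, not a precise measurement:

```latex
\frac{236\,\mu\mathrm{s} - 198\,\mu\mathrm{s}}{198\,\mu\mathrm{s}} = \frac{38}{198} \approx 0.19
```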
Doing something like the following will optimize the FunctionGraph in roughly the same way that aesara.function does:
...
Without that step, the JAX function will take the exact form of the log-likelihood graph determined by the Distribution.logp implementations (i.e., none of the usual rewrites such as common-subexpression elimination (CSE), fusions, or in-place operations).
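To make the "no CSE" point concrete, here is a toy sketch of common-subexpression elimination in plain Python. It is not Aesara's implementation: the nested-tuple graph format, the linearize and run helpers, and the op names are all hypothetical. It linearizes the same expression twice, once naively and once with a memo table keyed on (op, operands), and counts the emitted steps.

```python
import math

# Toy illustration only: this is NOT how Aesara represents or rewrites graphs.
# A node is either an input name (str) or a tuple (op_name, *child_nodes).
# Assumption: input names never look like generated temporaries ("t0", "t1", ...).

# Hypothetical op table used to evaluate the linearized programs.
OPS = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "exp": math.exp,
}


def linearize(root, use_cse):
    """Turn an expression tree into a list of (temp, op, operand_names) steps.

    With use_cse=True, an (op, operands) pair that was already emitted is
    reused instead of being computed again (common-subexpression elimination).
    """
    program = []  # emitted steps, in evaluation order
    memo = {}     # (op, operand_names) -> temp name holding its result

    def visit(node):
        if isinstance(node, str):
            return node  # graph input: referenced by name
        op, *children = node
        key = (op, tuple(visit(child) for child in children))
        if use_cse and key in memo:
            return memo[key]  # identical subexpression already computed
        name = f"t{len(program)}"
        program.append((name, op, key[1]))
        memo[key] = name
        return name

    output = visit(root)
    return program, output


def run(program, output, inputs):
    """Evaluate a linearized program given values for the graph inputs."""
    env = dict(inputs)
    for name, op, operands in program:
        env[name] = OPS[op](*(env[o] for o in operands))
    return env[output]


# exp(x * y) + exp(x * y): the product and the exp appear twice in the tree.
expr = ("add", ("exp", ("mul", "x", "y")), ("exp", ("mul", "x", "y")))

naive_program, naive_out = linearize(expr, use_cse=False)
cse_program, cse_out = linearize(expr, use_cse=True)
inputs = {"x": 1.0, "y": 2.0}

print(len(naive_program), "steps without CSE")  # 5 steps
print(len(cse_program), "steps with CSE")       # 3 steps
print(run(naive_program, naive_out, inputs) == run(cse_program, cse_out, inputs))  # True
```

In this toy, CSE cuts five steps to three with an identical result. Aesara's own rewrite passes (including the fusions and in-place operations mentioned above) are far more involved; the point is only that a graph handed to JAX without them is evaluated exactly as written.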