{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pyro.distributions as dist\n",
"import pytest\n",
"import torch\n",
"from jax.random import PRNGKey\n",
"import jax\n",
"import numpyro\n",
"# numpyro.set_platform(\"gpu\")\n",
"\n",
"from funsor.numpyro.hmm import GaussianHMM as NumGaussianHMM\n",
"from funsor.pyro.hmm import GaussianHMM\n",
"from funsor.testing import assert_close, random_mvn, randn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext snakeviz"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"batch_dim, time_dim, obs_dim, hidden_dim = 5, 6000, 3, 2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fehiepsi/jax/jax/lax/lax.py:4627: UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.int64'> requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
" warnings.warn(msg.format(dtype, fun_name , truncated_dtype))\n"
]
},
{
"data": {
"text/plain": [
"DeviceArray([-45199.684, -45400.703, -45539.02 , -45314.41 , -45346.246], dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"init_shape = (batch_dim,)\n",
"trans_mat_shape = trans_mvn_shape = obs_mat_shape = obs_mvn_shape = (batch_dim, time_dim)\n",
"init_dist = random_mvn(init_shape, hidden_dim, backend=\"numpy\")\n",
"trans_mat = randn(trans_mat_shape + (hidden_dim, hidden_dim), backend=\"numpy\")\n",
"trans_dist = random_mvn(trans_mvn_shape, hidden_dim, backend=\"numpy\")\n",
"obs_mat = randn(obs_mat_shape + (hidden_dim, obs_dim), backend=\"numpy\")\n",
"obs_dist = random_mvn(obs_mvn_shape, obs_dim, backend=\"numpy\")\n",
"\n",
"actual_dist = NumGaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)\n",
"data = obs_dist.sample(PRNGKey(0))\n",
"actual_dist.log_prob(data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The warning is a JAX issue. We can fix it upstream."
]
},
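{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that workaround, assuming 64-bit dtypes are actually wanted: enable x64 mode before any JAX computation runs. It is left commented out here because it would have to go at the very top of the notebook and would switch everything below to 64-bit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Workaround sketch (kept commented out, see the note above): enable 64-bit\n",
"# dtypes in JAX so the requested int64 is no longer truncated to int32.\n",
"# from jax.config import config\n",
"# config.update(\"jax_enable_x64\", True)\n",
"#\n",
"# Alternatively, set the JAX_ENABLE_X64=1 environment variable before launching."
]
},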
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-45199.6836, -45400.7031, -45539.0195, -45314.4102, -45346.2461])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pyro_init_dist = dist.MultivariateNormal(torch.from_numpy(init_dist.loc.copy()), scale_tril=torch.from_numpy(init_dist.scale_tril.copy()))\n",
"pyro_trans_dist = dist.MultivariateNormal(torch.from_numpy(trans_dist.loc.copy()), scale_tril=torch.from_numpy(trans_dist.scale_tril.copy()))\n",
"pyro_obs_dist = dist.MultivariateNormal(torch.from_numpy(obs_dist.loc.copy()), scale_tril=torch.from_numpy(obs_dist.scale_tril.copy()))\n",
"pyro_trans_mat = torch.from_numpy(trans_mat.copy()).float()\n",
"pyro_obs_mat = torch.from_numpy(obs_mat.copy()).float()\n",
"pyro_data = torch.from_numpy(data.copy())\n",
"\n",
"expect_dist = GaussianHMM(pyro_init_dist, pyro_trans_mat, pyro_trans_dist, pyro_obs_mat, pyro_obs_dist)\n",
"expect_dist.log_prob(pyro_data)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-45199.684, -45400.703, -45539.016, -45314.41 , -45346.25 ],\n",
" dtype=float32)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jit_log_prob = jax.jit(actual_dist.log_prob)\n",
"jit_log_prob(data).copy()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.56 ms ± 8.46 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%timeit actual_log_prob = jit_log_prob(data).copy()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"122 ms ± 649 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit expected_log_prob = expect_dist.log_prob(pyro_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### comparing to pyro GaussianHMM"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"expected_dist = dist.GaussianHMM(pyro_init_dist, pyro_trans_mat, pyro_trans_dist, pyro_obs_mat, pyro_obs_dist)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"36.6 ms ± 85.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit expected_log_prob = expected_dist.log_prob(pyro_data)"
]
}
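,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a sketch, not executed above; np.testing.assert_allclose and the tolerance are rough choices), the jitted NumPyro-backed log_prob can be compared numerically against Pyro's native GaussianHMM."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# The two implementations should agree up to float32 precision (compare the\n",
"# printed values above); rtol here is a rough float32-level tolerance.\n",
"actual_log_prob = np.asarray(jit_log_prob(data))\n",
"expected_log_prob = expected_dist.log_prob(pyro_data).numpy()\n",
"np.testing.assert_allclose(actual_log_prob, expected_log_prob, rtol=1e-4)"
]
}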
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}