Last active
February 13, 2020 04:46
-
-
Save fehiepsi/dc64771f87dd1af31f0c79d00ace5819 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax\n", | |
"import numpy as np\n", | |
"import numpyro\n", | |
"import pyro.distributions as dist\n", | |
"import pytest\n", | |
"import torch\n", | |
"from jax.random import PRNGKey\n", | |
"# numpyro.set_platform(\"gpu\")\n", | |
"\n", | |
"from funsor.numpyro.hmm import GaussianHMM as NumGaussianHMM\n", | |
"from funsor.pyro.hmm import GaussianHMM\n", | |
"from funsor.testing import assert_close, random_mvn, randn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Registers the %snakeviz / %%snakeviz profiling magics for digging into the\n", | |
"# timings below; no cell in this notebook actually invokes them.\n", | |
"%load_ext snakeviz" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Problem size shared by every model below: a batch of 5 HMMs, each spanning\n", | |
"# 6000 time steps, with 3-dim observations and a 2-dim hidden state.\n", | |
"batch_dim = 5\n", | |
"time_dim = 6000\n", | |
"obs_dim = 3\n", | |
"hidden_dim = 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/fehiepsi/jax/jax/lax/lax.py:4627: UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.int64'> requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n", | |
" warnings.warn(msg.format(dtype, fun_name , truncated_dtype))\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray([-45199.684, -45400.703, -45539.02 , -45314.41 , -45346.246], dtype=float32)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Seed NumPy's global RNG so Restart & Run All regenerates the same model parameters\n", | |
"# (assumes funsor.testing's numpy backend draws from np.random -- TODO confirm).\n", | |
"np.random.seed(0)\n", | |
"\n", | |
"init_shape = (batch_dim,)\n", | |
"trans_mat_shape = trans_mvn_shape = obs_mat_shape = obs_mvn_shape = (batch_dim, time_dim)\n", | |
"init_dist = random_mvn(init_shape, hidden_dim, backend=\"numpy\")\n", | |
"trans_mat = randn(trans_mat_shape + (hidden_dim, hidden_dim), backend=\"numpy\")\n", | |
"trans_dist = random_mvn(trans_mvn_shape, hidden_dim, backend=\"numpy\")\n", | |
"obs_mat = randn(obs_mat_shape + (hidden_dim, obs_dim), backend=\"numpy\")\n", | |
"obs_dist = random_mvn(obs_mvn_shape, obs_dim, backend=\"numpy\")\n", | |
"\n", | |
"actual_dist = NumGaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)\n", | |
"# `data` is drawn from the observation-noise distribution rather than from the HMM;\n", | |
"# that suffices here because the notebook only compares log_prob values and timings.\n", | |
"data = obs_dist.sample(PRNGKey(0))\n", | |
"actual_dist.log_prob(data)" | |
] | |
}, | |
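{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The `UserWarning` above comes from JAX itself: an `int64` dtype requested in `arange` is truncated to `int32` because `jax_enable_x64` is disabled. It is a JAX issue that can be fixed upstream, and the log-probabilities still agree with the PyTorch computation below." | |
] | |
}, | |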
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([-45199.6836, -45400.7031, -45539.0195, -45314.4102, -45346.2461])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def to_pyro_mvn(mvn):\n", | |
"    \"\"\"Return a Pyro MultivariateNormal with the same loc and scale_tril as `mvn`.\n", | |
"\n", | |
"    `.copy()` is applied before `torch.from_numpy`, presumably to hand it a plain,\n", | |
"    writable NumPy array (TODO confirm it is still needed).\n", | |
"    \"\"\"\n", | |
"    loc = torch.from_numpy(mvn.loc.copy())\n", | |
"    scale_tril = torch.from_numpy(mvn.scale_tril.copy())\n", | |
"    return dist.MultivariateNormal(loc, scale_tril=scale_tril)\n", | |
"\n", | |
"\n", | |
"# Mirror every NumPyro input as a PyTorch object so funsor's Pyro backend sees the same model.\n", | |
"pyro_init_dist = to_pyro_mvn(init_dist)\n", | |
"pyro_trans_dist = to_pyro_mvn(trans_dist)\n", | |
"pyro_obs_dist = to_pyro_mvn(obs_dist)\n", | |
"# `.float()` casts the matrices to float32 so they share the dtype of the other inputs\n", | |
"# (the log_prob shown below is float32).\n", | |
"pyro_trans_mat = torch.from_numpy(trans_mat.copy()).float()\n", | |
"pyro_obs_mat = torch.from_numpy(obs_mat.copy()).float()\n", | |
"pyro_data = torch.from_numpy(data.copy())\n", | |
"\n", | |
"expect_dist = GaussianHMM(pyro_init_dist, pyro_trans_mat, pyro_trans_dist, pyro_obs_mat, pyro_obs_dist)\n", | |
"expect_dist.log_prob(pyro_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([-45199.684, -45400.703, -45539.016, -45314.41 , -45346.25 ],\n", | |
" dtype=float32)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"jit_log_prob = jax.jit(actual_dist.log_prob)\n", | |
"# The first call triggers compilation, so the %timeit below measures only the compiled function.\n", | |
"actual_log_prob = jit_log_prob(data).copy()\n", | |
"# Correctness check before benchmarking: the JIT-compiled NumPyro result must match\n", | |
"# funsor's Pyro backend. Both are float32 and reduce over `time_dim` steps, so compare\n", | |
"# with a relative tolerance instead of exact equality.\n", | |
"assert_close(torch.from_numpy(actual_log_prob), expect_dist.log_prob(pyro_data), rtol=1e-5)\n", | |
"actual_log_prob" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.56 ms ± 8.46 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# JAX dispatches asynchronously; .copy() materializes the result so each loop times the full computation.\n", | |
"%timeit jit_log_prob(data).copy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"122 ms ± 649 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Baseline: funsor's Pyro-backed GaussianHMM, evaluated without JIT.\n", | |
"%timeit expect_dist.log_prob(pyro_data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Comparing to Pyro's built-in `dist.GaussianHMM`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Pyro's own GaussianHMM built from the same PyTorch inputs. Mind the names:\n", | |
"# `expected_dist` is Pyro's built-in class, `expect_dist` is funsor's Pyro backend.\n", | |
"expected_dist = dist.GaussianHMM(pyro_init_dist, pyro_trans_mat, pyro_trans_dist, pyro_obs_mat, pyro_obs_dist)\n", | |
"# Confirm both implementations compute the same log-density before comparing their speed.\n", | |
"assert_close(expected_dist.log_prob(pyro_data), expect_dist.log_prob(pyro_data), rtol=1e-5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"36.6 ms ± 85.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Baseline: Pyro's built-in GaussianHMM, evaluated without JIT.\n", | |
"%timeit expected_dist.log_prob(pyro_data)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment