Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Created December 28, 2018 02:15
Show Gist options
  • Save fehiepsi/75dfbea31b993f165f51524776185be6 to your computer and use it in GitHub Desktop.
Save fehiepsi/75dfbea31b993f165f51524776185be6 to your computer and use it in GitHub Desktop.
benchmark pyro nuts
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"import pyro.poutine as poutine\n",
"from pyro.infer.mcmc import MCMC, NUTS\n",
"\n",
"# Global run configuration: turn on Pyro's validation checks, seed the RNG (seed 1)\n",
"# so the synthetic data below is repeatable, and make float64 the default tensor dtype.\n",
"pyro.enable_validation(True)\n",
"pyro.set_rng_seed(1)\n",
"torch.set_default_dtype(torch.double)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Problem sizes: hidden categories, vocabulary size, and the two sequence lengths.\n",
"num_categories = 3\n",
"num_words = 10\n",
"num_data = 100\n",
"num_unsup_data = 500\n",
"\n",
"# Dirichlet concentrations: 1.0 for every transition entry, 0.1 for every emission entry.\n",
"transition_prior = torch.ones(num_categories)\n",
"emission_prior = torch.full((num_words,), 0.1)\n",
"\n",
"# Ground-truth parameters: one probability row per hidden category.\n",
"# Transitions are drawn before emissions so the RNG stream is consumed in a fixed order.\n",
"row_shape = torch.Size([num_categories])\n",
"transition_prob = dist.Dirichlet(transition_prior).sample(row_shape)\n",
"emission_prob = dist.Dirichlet(emission_prior).sample(row_shape)\n",
"\n",
"def equilibrium(mc_matrix):\n",
"    \"\"\"Return the equilibrium distribution of a square Markov transition matrix.\n",
"\n",
"    Solves (I - P^T + J) x = 1, where P is ``mc_matrix``, J is the all-ones\n",
"    matrix and 1 is the all-ones vector. A vector that is stationary under P\n",
"    (rows of P being next-state distributions) and sums to one satisfies this system.\n",
"    \"\"\"\n",
"    num_states = mc_matrix.size(0)\n",
"    # Adding the scalar 1 broadcasts to every entry, i.e. adds the all-ones matrix J.\n",
"    system = torch.eye(num_states) - mc_matrix.t() + 1\n",
"    rhs = torch.ones(num_states)\n",
"    return system.inverse().matmul(rhs)\n",
"\n",
"start_prob = equilibrium(transition_prob)\n",
"\n",
"\n",
"def simulate_hmm(num_steps):\n",
"    \"\"\"Draw one sequence of ``num_steps`` (category, word) pairs from the ground-truth HMM.\n",
"\n",
"    The first category is drawn from ``start_prob``; each later category is drawn\n",
"    from the ``transition_prob`` row of the previous category, and every word is\n",
"    drawn from the ``emission_prob`` row of the current category.\n",
"\n",
"    Returns a ``(categories, words)`` pair of stacked tensors, one entry per step.\n",
"    \"\"\"\n",
"    states, observations = [], []\n",
"    state = dist.Categorical(start_prob).sample()\n",
"    for t in range(num_steps):\n",
"        if t > 0:\n",
"            state = dist.Categorical(transition_prob[state]).sample()\n",
"        observation = dist.Categorical(emission_prob[state]).sample()\n",
"        states.append(state)\n",
"        observations.append(observation)\n",
"    return torch.stack(states), torch.stack(observations)\n",
"\n",
"\n",
"# supervised data\n",
"categories, words = simulate_hmm(num_data)\n",
"\n",
"# unsupervised data (drawn second so the RNG stream is consumed in the same order as before)\n",
"unsup_categories, unsup_words = simulate_hmm(num_unsup_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pyro"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# One Dirichlet concentration row per hidden category. Shapes follow the configured\n",
"# sizes rather than the literals 3 and 10, so changing num_categories / num_words\n",
"# cannot leave the model priors out of sync with the data.\n",
"transition_prior_ = transition_prior.expand(num_categories, num_categories)\n",
"emission_prior_ = emission_prior.expand(num_categories, num_words)\n",
"\n",
"def supervised_hmm(categories, words):\n",
"    \"\"\"Supervised HMM with Dirichlet priors on the transition and emission rows.\n",
"\n",
"    Both ``categories`` and ``words`` are passed as observations, so the only\n",
"    latent sites are ``transition_prob`` and ``emission_prob``. The first\n",
"    category is taken directly from ``categories[0]`` and has no sample site.\n",
"    \"\"\"\n",
"    # to_event() is used in place of an explicit pyro.plate over categories: it\n",
"    # treats the per-category batch dimension as part of a single joint event.\n",
"    transition_prob = pyro.sample(\"transition_prob\", dist.Dirichlet(transition_prior_).to_event())\n",
"    # This must be bound to `emission_prob`: the word likelihood below indexes that\n",
"    # name, and any other local name makes it fall through to the module-level\n",
"    # ground-truth tensor, so the latent sample would never affect the likelihood.\n",
"    emission_prob = pyro.sample(\"emission_prob\", dist.Dirichlet(emission_prior_).to_event())\n",
"\n",
"    category = categories[0]\n",
"    for t in range(words.size(0)):\n",
"        if t > 0:\n",
"            # Observed transition from the previous category to the current one.\n",
"            category = pyro.sample(\"category_{}\".format(t),\n",
"                                   dist.Categorical(transition_prob[category]),\n",
"                                   obs=categories[t])\n",
"        # Observed emission of the word at step t from the current category.\n",
"        pyro.sample(\"word_{}\".format(t),\n",
"                    dist.Categorical(emission_prob[category]),\n",
"                    obs=words[t])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "62a189f48de646edbc30c31a69826464",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Warmup', max=200, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"CPU times: user 1min 6s, sys: 344 ms, total: 1min 6s\n",
"Wall time: 1min 6s\n"
]
}
],
"source": [
"%%time\n",
"# Time NUTS on the supervised model with a JIT-compiled kernel: 100 warmup steps\n",
"# followed by 100 kept samples.\n",
"# NOTE(review): this is meant to match the Stan run's chains=1, iter=200 budget,\n",
"# assuming PyStan spends half of `iter` on warmup by default - confirm.\n",
"nuts_kernel = NUTS(supervised_hmm, jit_compile=True, ignore_jit_warnings=True)\n",
"mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=100).run(categories, words)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Stan"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data {\n",
" int<lower=1> K; // num categories\n",
" int<lower=1> V; // num words\n",
" int<lower=0> T; // num instances\n",
" int<lower=1,upper=V> w[T]; // words\n",
" int<lower=1,upper=K> z[T]; // categories\n",
" vector<lower=0>[K] alpha; // transit prior\n",
" vector<lower=0>[V] beta; // emit prior\n",
"}\n",
"parameters {\n",
" simplex[K] theta[K]; // transit probs\n",
" simplex[V] phi[K]; // emit probs\n",
"}\n",
"model {\n",
" for (k in 1:K) \n",
" theta[k] ~ dirichlet(alpha);\n",
" for (k in 1:K)\n",
" phi[k] ~ dirichlet(beta);\n",
" for (t in 1:T)\n",
" w[t] ~ categorical(phi[z[t]]);\n",
" for (t in 2:T)\n",
" z[t] ~ categorical(theta[z[t - 1]]);\n",
"}\n",
"\n"
]
}
],
"source": [
"# `import urllib` alone does not guarantee the `request` submodule is loaded,\n",
"# so import it explicitly instead of relying on another package having done so.\n",
"import urllib.request\n",
"\n",
"import pystan\n",
"\n",
"# NOTE: this points at the moving `master` branch, so the downloaded program text\n",
"# can change between runs; pin a commit hash in the URL for a repeatable benchmark.\n",
"url = \"https://raw.githubusercontent.com/stan-dev/example-models/master/misc/hmm/hmm.stan\"\n",
"# The context manager closes the HTTP response once the body has been read.\n",
"with urllib.request.urlopen(url) as response:\n",
"    stan_model = response.read().decode(\"utf-8\")\n",
"print(stan_model)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_f55cca0281f9c2977be46938c64909a7 NOW.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 821 ms, sys: 20.1 ms, total: 841 ms\n",
"Wall time: 32.2 s\n"
]
}
],
"source": [
"%%time\n",
"# Build the PyStan model from the program text held in `stan_model`. This is timed\n",
"# in its own cell so the one-off build cost stays separate from the sampling time.\n",
"model = pystan.StanModel(model_code=stan_model)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:pystan:10 of 100 iterations ended with a divergence (10 %).\n",
"WARNING:pystan:Try running with adapt_delta larger than 0.8 to remove the divergences.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 903 ms, sys: 12 ms, total: 915 ms\n",
"Wall time: 911 ms\n"
]
}
],
"source": [
"%%time\n",
"# Hand PyStan NumPy arrays rather than torch tensors, so the data dict does not\n",
"# depend on implicit tensor-to-array coercion inside PyStan.\n",
"# The Stan program declares w and z with lower=1, hence the +1 shift on the\n",
"# 0-based word and category ids.\n",
"data = {\"K\": num_categories, \"V\": num_words, \"T\": num_data,\n",
"        \"alpha\": transition_prior.numpy(), \"beta\": emission_prior.numpy(),\n",
"        \"w\": (words + 1).numpy(), \"z\": (categories + 1).numpy()}\n",
"\n",
"fit = model.sampling(data, chains=1, iter=200)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment