Created
December 28, 2018 02:15
-
-
Save fehiepsi/75dfbea31b993f165f51524776185be6 to your computer and use it in GitHub Desktop.
benchmark pyro nuts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"import pyro\n", | |
"import pyro.distributions as dist\n", | |
"import pyro.poutine as poutine\n", | |
"from pyro.infer.mcmc import MCMC, NUTS\n", | |
"\n", | |
"# Global configuration: double precision, Pyro argument validation, fixed seed.\n", | |
"torch.set_default_dtype(torch.double)\n", | |
"pyro.enable_validation(True)\n", | |
"SEED = 1\n", | |
"pyro.set_rng_seed(SEED)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"num_categories = 3\n", | |
"num_words = 10\n", | |
"num_data = 100\n", | |
"num_unsup_data = 500\n", | |
"\n", | |
"# Symmetric Dirichlet priors: flat (1.0) over next categories, sparse (0.1) over words.\n", | |
"transition_prior = torch.empty(num_categories).fill_(1.)\n", | |
"emission_prior = torch.empty(num_words).fill_(0.1)\n", | |
"\n", | |
"# Ground-truth parameters: one probability row per category.\n", | |
"transition_prob = dist.Dirichlet(transition_prior).sample(torch.Size([num_categories]))\n", | |
"emission_prob = dist.Dirichlet(emission_prior).sample(torch.Size([num_categories]))\n", | |
"\n", | |
"def equilibrium(mc_matrix):\n", | |
"    # Stationary distribution pi of a row-stochastic transition matrix P, found by\n", | |
"    # solving (I - P^T + J) pi = 1, where J is the all-ones matrix (the `+ 1`).\n", | |
"    n = mc_matrix.size(0)\n", | |
"    return (torch.eye(n) - mc_matrix.t() + 1).inverse().matmul(torch.ones(n))\n", | |
"\n", | |
"def simulate_hmm(num_steps):\n", | |
"    # Draw a (categories, words) trajectory of length num_steps from the ground-truth\n", | |
"    # HMM, with the first category drawn from the chain's stationary distribution.\n", | |
"    categories, words = [], []\n", | |
"    category = dist.Categorical(start_prob).sample()\n", | |
"    for t in range(num_steps):\n", | |
"        if t > 0:\n", | |
"            category = dist.Categorical(transition_prob[category]).sample()\n", | |
"        word = dist.Categorical(emission_prob[category]).sample()\n", | |
"        categories.append(category)\n", | |
"        words.append(word)\n", | |
"    return torch.stack(categories), torch.stack(words)\n", | |
"\n", | |
"start_prob = equilibrium(transition_prob)\n", | |
"\n", | |
"# supervised data\n", | |
"categories, words = simulate_hmm(num_data)\n", | |
"\n", | |
"# unsupervised data (not used by the benchmarks below)\n", | |
"unsup_categories, unsup_words = simulate_hmm(num_unsup_data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Pyro" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Prior concentrations broadcast to one Dirichlet row per category.\n", | |
"transition_prior_ = transition_prior.expand(num_categories, num_categories)\n", | |
"emission_prior_ = emission_prior.expand(num_categories, num_words)\n", | |
"\n", | |
"def supervised_hmm(categories, words):\n", | |
"    # Fully observed HMM: transition and emission probabilities are latent with\n", | |
"    # Dirichlet priors; every category and word is observed. As in the Stan model\n", | |
"    # below, categories[0] is conditioned on rather than scored.\n", | |
"    # `.to_event()` reinterprets the per-category batch dim as part of the event,\n", | |
"    # so each matrix is a single sample site.\n", | |
"    transition_prob = pyro.sample(\"transition_prob\", dist.Dirichlet(transition_prior_).to_event())\n", | |
"    # Bound as `emission_prob` so the likelihood uses this latent, not the global\n", | |
"    # ground-truth matrix of the same name from the data-generation cell.\n", | |
"    emission_prob = pyro.sample(\"emission_prob\", dist.Dirichlet(emission_prior_).to_event())\n", | |
"\n", | |
"    category = categories[0]\n", | |
"    for t in range(words.size(0)):\n", | |
"        if t > 0:\n", | |
"            category = pyro.sample(\"category_{}\".format(t),\n", | |
"                                   dist.Categorical(transition_prob[category]),\n", | |
"                                   obs=categories[t])\n", | |
"        pyro.sample(\"word_{}\".format(t),\n", | |
"                    dist.Categorical(emission_prob[category]),\n", | |
"                    obs=words[t])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "62a189f48de646edbc30c31a69826464", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Warmup', max=200, style=ProgressStyle(description_width='init…" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"CPU times: user 1min 6s, sys: 344 ms, total: 1min 6s\n", | |
"Wall time: 1min 6s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# One NUTS chain: 100 warmup steps followed by 100 retained samples.\n", | |
"nuts_kernel = NUTS(supervised_hmm, ignore_jit_warnings=True, jit_compile=True)\n", | |
"mcmc = MCMC(nuts_kernel, warmup_steps=100, num_samples=100).run(categories, words)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Stan" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"data {\n", | |
" int<lower=1> K; // num categories\n", | |
" int<lower=1> V; // num words\n", | |
" int<lower=0> T; // num instances\n", | |
" int<lower=1,upper=V> w[T]; // words\n", | |
" int<lower=1,upper=K> z[T]; // categories\n", | |
" vector<lower=0>[K] alpha; // transit prior\n", | |
" vector<lower=0>[V] beta; // emit prior\n", | |
"}\n", | |
"parameters {\n", | |
" simplex[K] theta[K]; // transit probs\n", | |
" simplex[V] phi[K]; // emit probs\n", | |
"}\n", | |
"model {\n", | |
" for (k in 1:K) \n", | |
" theta[k] ~ dirichlet(alpha);\n", | |
" for (k in 1:K)\n", | |
" phi[k] ~ dirichlet(beta);\n", | |
" for (t in 1:T)\n", | |
" w[t] ~ categorical(phi[z[t]]);\n", | |
" for (t in 2:T)\n", | |
" z[t] ~ categorical(theta[z[t - 1]]);\n", | |
"}\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# `import urllib` alone does not load the `urllib.request` submodule.\n", | |
"import urllib.request\n", | |
"\n", | |
"import pystan\n", | |
"\n", | |
"# Reference HMM program from the Stan example-models repository (needs network access).\n", | |
"url = \"https://raw.githubusercontent.com/stan-dev/example-models/master/misc/hmm/hmm.stan\"\n", | |
"with urllib.request.urlopen(url) as response:\n", | |
"    stan_model = response.read().decode(\"utf-8\")\n", | |
"print(stan_model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_f55cca0281f9c2977be46938c64909a7 NOW.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 821 ms, sys: 20.1 ms, total: 841 ms\n", | |
"Wall time: 32.2 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# Compile the Stan program to C++ once; timed separately from sampling below.\n", | |
"model = pystan.StanModel(model_code=stan_model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:pystan:10 of 100 iterations ended with a divergence (10 %).\n", | |
"WARNING:pystan:Try running with adapt_delta larger than 0.8 to remove the divergences.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 903 ms, sys: 12 ms, total: 915 ms\n", | |
"Wall time: 911 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# The Stan program declares w and z with lower=1, hence the +1 shift of the\n", | |
"# 0-based indices. Tensors are converted to NumPy explicitly rather than relying\n", | |
"# on PyStan to convert torch tensors.\n", | |
"data = {\"K\": num_categories, \"V\": num_words, \"T\": num_data,\n", | |
"        \"alpha\": transition_prior.numpy(), \"beta\": emission_prior.numpy(),\n", | |
"        \"w\": (words + 1).numpy(), \"z\": (categories + 1).numpy()}\n", | |
"\n", | |
"# 100 warmup + 100 post-warmup iterations on one chain, matching the Pyro run;\n", | |
"# seeded so the Stan chain is reproducible like the seeded Pyro chain.\n", | |
"fit = model.sampling(data, chains=1, iter=200, warmup=100, seed=1)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment