@Orbifold
Created August 23, 2018 17:36
Simplistic example in Pyro.ai.
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Pyro\n",
"\nGetting to know the Pyro API."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sea\n",
"sea.set()\n",
"%matplotlib inline"
],
"outputs": [],
"execution_count": 18,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"The nice thing about Pyro is that the API feels natural: a distribution object can be sampled and scored directly."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"normal = dist.Normal(0, 1) \n",
"x = normal.sample() \n",
"print(f\"sample: {x}\")\n",
"print(f\"log probability: {normal.log_prob(x)}\") "
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample: -1.6036932468414307\n",
"log probability: -2.2048544883728027\n"
]
}
],
"execution_count": 4,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
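{
"cell_type": "markdown",
"source": [
"As a quick sanity check (a sketch, not part of the original walkthrough), the log probability printed above can be recomputed from the closed-form standard normal density, log p(x) = -x^2/2 - log(2*pi)/2. The sample value is copied from the output above; the two numbers should agree up to float32 rounding."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import math\n",
"import torch\n",
"import pyro.distributions as dist\n",
"\n",
"# Sample value copied from the output of the cell above\n",
"x = torch.tensor(-1.6036932468414307)\n",
"normal = dist.Normal(0., 1.)\n",
"\n",
"# Closed-form log density of a standard normal: -x^2/2 - log(2*pi)/2\n",
"manual = -0.5 * x.item() ** 2 - 0.5 * math.log(2 * math.pi)\n",
"print(f'pyro log_prob: {normal.log_prob(x).item():.6f}')\n",
"print(f'by hand:       {manual:.6f}')"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},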
{
"cell_type": "markdown",
"source": [
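"Defining a random variable works differently than in PyMC3: you call `pyro.sample` with a unique name for the sample site and a distribution. A normally distributed variable named 'size' would be:"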
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"x = pyro.sample(\"size\", normal)\n",
"print(x)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor(-1.1752)\n"
]
}
],
"execution_count": 5,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"To extract the value of a scalar tensor as a plain Python number, use `item()`:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"print(x.item())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"-1.1751805543899536\n"
]
}
],
"execution_count": 6,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
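{
"cell_type": "markdown",
"source": [
"The name given to `pyro.sample` is what lets Pyro's effect handlers find the variable later. A minimal sketch, assuming a Pyro release that ships `pyro.poutine.trace` (current releases do); `size_model` is a hypothetical wrapper introduced only for this illustration. It records one run of the model and looks the site up by name:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pyro\n",
"import pyro.distributions as dist\n",
"from pyro import poutine\n",
"\n",
"def size_model():\n",
"    # Hypothetical wrapper around the single named sample site used above\n",
"    return pyro.sample('size', dist.Normal(0., 1.))\n",
"\n",
"# poutine.trace records every named sample site visited during one run\n",
"trace = poutine.trace(size_model).get_trace()\n",
"site = trace.nodes['size']\n",
"print(site['type'], site['value'].item())\n",
"print('joint log probability of the run:', trace.log_prob_sum().item())"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},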
{
"cell_type": "markdown",
"source": [
"## Weather model\n",
"By combining random variables you can create a simplistic weather model like so:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def weather():\n",
"    cloudy = pyro.sample('cloudy', dist.Bernoulli(0.3))\n",
"    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'\n",
"    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]\n",
"    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]\n",
"    temp = pyro.sample('temp', dist.Normal(mean_temp, scale_temp))\n",
"    return cloudy, temp.item()\n"
],
"outputs": [],
"execution_count": 8,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"The logic is straightforward; calling the model a few times yields a few sample states:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"for _ in range(3):\n",
"    print(weather())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"('cloudy', 51.59914779663086)\n",
"('cloudy', 57.596649169921875)\n",
"('sunny', 79.6128158569336)\n"
]
}
],
"execution_count": 11,
"metadata": {}
},
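{
"cell_type": "markdown",
"source": [
"Because each random choice has a name, it can also be fixed from outside the model. A minimal sketch using `pyro.condition` to force the cloudy branch (the value `1.` matches what `dist.Bernoulli` returns for a cloudy draw); the average temperature over many runs should then land near the cloudy mean of 55:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import pyro\n",
"\n",
"# Pin the 'cloudy' site to 1.0 so every run takes the cloudy branch\n",
"cloudy_weather = pyro.condition(weather, data={'cloudy': torch.tensor(1.)})\n",
"\n",
"temps = [cloudy_weather()[1] for _ in range(1000)]\n",
"print(f'mean temperature on cloudy days: {sum(temps) / len(temps):.1f}')"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},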
{
"cell_type": "markdown",
"source": [
"Let's model ice-cream sales as a function of the weather:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def ice_cream_sales():\n",
"    cloudy, temp = weather()\n",
"    expected_sales = 200. if cloudy == 'sunny' and temp > 80.0 else 50.\n",
"    ice_cream = pyro.sample('ice_cream', dist.Normal(expected_sales, 10.0))\n",
"    return ice_cream"
],
"outputs": [],
"execution_count": 12,
"metadata": {}
},
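{
"cell_type": "markdown",
"source": [
"Since expected sales only jump to 200 on sunny days warmer than 80, the probability of that branch can be worked out by hand: P(sunny) = 0.7 and, on sunny days, temp ~ Normal(75, 15), so P(temp > 80 | sunny) = 1 - Phi(1/3), where Phi is the standard normal CDF. The sketch below uses only the standard library (`math.erf`); the product comes out at roughly a quarter."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import math\n",
"\n",
"p_sunny = 1.0 - 0.3  # Bernoulli(0.3) is the probability of clouds\n",
"z = (80.0 - 75.0) / 15.0  # standardized threshold on sunny days\n",
"# Standard normal CDF expressed through the error function\n",
"phi = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))\n",
"p_high_sales = p_sunny * (1.0 - phi)\n",
"print(f'P(expected sales = 200) = {p_high_sales:.3f}')"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},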
{
"cell_type": "markdown",
"source": [
"Unlike in some other PPL frameworks, you can get an idea of the model's behavior just by calling the definition:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"print(f\"Possible sales of ice-cream: {ice_cream_sales():.2f}\")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Possible sales of ice-cream: 50.69\n"
]
}
],
"execution_count": 23,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"The key idea behind a PPL is to obtain the marginal distribution of the variable of interest. In simple terms, the marginal is the probability distribution of that variable with all the inner dependencies (here, cloudiness and temperature) integrated out.\n",
"\nTo do this in Pyro, you first define the posterior, here approximated by importance sampling with 100 samples:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"posterior = pyro.infer.Importance(ice_cream_sales, num_samples=100)"
],
"outputs": [],
"execution_count": 14,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"and feed the result of running it to `EmpiricalMarginal`:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"marginal = pyro.infer.EmpiricalMarginal(posterior.run())"
],
"outputs": [],
"execution_count": 17,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
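{
"cell_type": "markdown",
"source": [
"The marginal is, in essence, a Monte Carlo average over the weather. As a rough check (a sketch; the threshold 125 is an arbitrary cut between the two sales levels of 50 and 200), draws from `marginal` can be compared with plain repeated runs of `ice_cream_sales`. With only 100 importance samples behind `marginal`, expect the two summaries to agree only approximately:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"# Draws from the empirical marginal built above (resampled from 100 traces)\n",
"from_marginal = torch.stack([marginal() for _ in range(1000)])\n",
"# Plain forward simulation also integrates out cloudiness and temperature\n",
"from_model = torch.stack([ice_cream_sales() for _ in range(1000)])\n",
"\n",
"for name, s in [('marginal', from_marginal), ('model', from_model)]:\n",
"    print(f'{name}: mean={s.mean().item():.1f}, std={s.std().item():.1f}, '\n",
"          f'P(sales > 125)={(s > 125).float().mean().item():.2f}')"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},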
{
"cell_type": "markdown",
"source": [
"To view it as a histogram, use something like this:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"plt.hist([marginal().item() for _ in range(1000)], range=(5.0, 100.0))\n",
"plt.title(\"P(sales)\")\n",
"plt.xlabel(\"sales\")\n",
"plt.ylabel(\"#\");"
],
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {}
}
],
"execution_count": 24,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
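{
"cell_type": "markdown",
"source": [
"The `range=(5.0, 100.0)` argument above clips the plot. Dropping it (a variant of the same call; the bin count of 40 is an arbitrary choice) also shows the smaller cluster near 200 produced on hot, sunny days:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"samples = torch.stack([marginal() for _ in range(1000)])\n",
"# No range= argument, so values near 200 are not cut off\n",
"plt.hist(samples.numpy(), bins=40)\n",
"plt.title('P(sales), full range')\n",
"plt.xlabel('sales')\n",
"plt.ylabel('#');"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},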
{
"cell_type": "markdown",
"source": [
"Considering the weather model and how sales depend on it, the most likely value is around 50.\n",
"\nBoth the way the story is assembled and how concrete the result is strike me as fantastically clear. Bayesian reasoning and PPLs can feel abstract at first, but this example brings it all together for me."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [],
"outputs": [],
"execution_count": null,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
}
],
"metadata": {
"kernel_info": {
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.6.3",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3"
},
"nteract": {
"version": "0.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}