Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save ltiao/95a8e514291f118ffa4e6e143674f989 to your computer and use it in GitHub Desktop.
Save ltiao/95a8e514291f118ffa4e6e143674f989 to your computer and use it in GitHub Desktop.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"from ipywidgets import interact"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# The density plots evaluate on a grid of 2**log_n points over [0, 1].\n",
"log_n = 9\n",
"\n",
"# Default distribution parameters; the interactive cells below receive their own\n",
"# `probs` / `temperature` slider arguments, which shadow these globals.\n",
"probs, temperature = 0.8, 0.5\n",
"\n",
"# Number of draws used by each sample histogram.\n",
"num_samples = 5000"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Start from the colour-blind-friendly seaborn style, then override rcParams.\n",
"plt.style.use('seaborn-colorblind')\n",
"\n",
"plt.rcParams.update({\n",
"    'text.usetex': True,\n",
"    'font.family': 'serif',\n",
"    'font.serif': ['Lato'],\n",
"    'font.size': 16,\n",
"    # Machine-specific absolute path to the `convert` executable.\n",
"    'animation.convert_path': '/usr/bin/convert',\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'TensorFlow version: 1.10.0'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# TensorFlow version this notebook was executed with.\n",
"'TensorFlow version: {}'.format(tf.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"tf.enable_eager_execution()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def golden_size(width):\n",
"    \"\"\"Return a matplotlib ``(width, height)`` figsize in golden-ratio proportions.\n",
"\n",
"    height = width / phi with phi = (1 + sqrt(5)) / 2, i.e. 2 * width / (1 + sqrt(5)).\n",
"    \"\"\"\n",
"    return (width, 2. * width / (1 + np.sqrt(5)))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Evaluation grid: 2**log_n equally spaced points on [0, 1], endpoints included.\n",
"x = tf.linspace(0., 1., 2 ** log_n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bernoulli"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "24f290e7c8fd40ccb42905205694f3ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.5, description='probs', max=0.99, min=0.05, step=0.05), Output()), _…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"@interact(probs=(0.05, 0.99, 0.05))\n",
"def plot_bernoulli(probs):\n",
"    \"\"\"Stem plot of the Bernoulli(probs) probability mass at x = 0 and x = 1.\n",
"\n",
"    The dashed curve is `d.prob` evaluated on the continuous grid `x`.\n",
"    \"\"\"\n",
"    d = tfp.distributions.Bernoulli(probs=probs)\n",
"\n",
"    # Probability of each support point, read back as a NumPy array.\n",
"    support = np.array([0., 1.], dtype=np.float32)\n",
"    mass = d.prob(support).numpy()\n",
"\n",
"    fig, ax = plt.subplots(figsize=golden_size(8))\n",
"\n",
"    ax.set_title(r'$p(x) = \\mathrm{{Bernoulli}}(\\rho={:0.2f})$'.format(probs))\n",
"\n",
"    ax.plot(x.numpy(), d.prob(x).numpy(), linestyle='--')\n",
"    # `vlines` takes its y-limits in data coordinates (`axvline` would treat them\n",
"    # as fractions of the axes height, off by the 1.1 ylim factor), so each stem\n",
"    # ends exactly at its probability; the marker is drawn only at the top.\n",
"    ax.vlines(support, 0., mass, color='k')\n",
"    ax.plot(support, mass, linestyle='none', marker='o', color='k')\n",
"\n",
"    ax.set_xlabel(r'$x$')\n",
"\n",
"    ax.set_ylabel(r'$p(x)$')\n",
"    ax.set_ylim(0., 1.1)\n",
"\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6ef5ec09ff140fb98aa198aa30754d6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.5, description='probs', max=0.99, min=0.05, step=0.05), Output()), _…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"@interact(probs=(0.05, 0.99, 0.05))\n",
"def plot_bernoulli_samples(probs):\n",
"    \"\"\"Histogram of `num_samples` draws from Bernoulli(probs).\"\"\"\n",
"    d = tfp.distributions.Bernoulli(probs=probs)\n",
"    # Explicit NumPy array, matching the `.numpy()` conversions in the density cells.\n",
"    samples = d.sample(num_samples).numpy()\n",
"\n",
"    fig, ax = plt.subplots(figsize=golden_size(8))\n",
"\n",
"    ax.set_title(r'$p(x) = \\mathrm{{Bernoulli}}(\\rho={:0.2f})$'.format(probs))\n",
"\n",
"    # Unit-width bins centred on the two support points: one bar counts the 0s,\n",
"    # the other the 1s, with the same bar layout for every value of `probs`.\n",
"    sns.distplot(samples, bins=[-0.5, 0.5, 1.5], kde=False, hist=True, ax=ax)\n",
"\n",
"    ax.set_ylim(0, num_samples)\n",
"    ax.set_ylabel('nbr. of samples')\n",
"    ax.set_xlabel(r'$x$')\n",
"\n",
"    plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Relaxed Bernoulli (Binary Concrete)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4b32e4cb30e14b52b24e77434b403dea",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.5, description='probs', max=0.99, min=0.05, step=0.05), FloatSlider(…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"@interact(probs=(0.05, 0.99, 0.05), temperature=(0.01, 3.0, 0.1))\n",
"def plot_relaxed_bernoulli(probs, temperature):\n",
"    \"\"\"Plot the RelaxedBernoulli (binary Concrete) density over the grid `x`.\n",
"\n",
"    The y-axis is fixed to [0, 5].\n",
"    \"\"\"\n",
"    relaxed = tfp.distributions.RelaxedBernoulli(probs=probs, temperature=temperature)\n",
"    density = relaxed.prob(x).numpy()\n",
"    title = (r'$p(x) = \\mathrm{{BinConcrete}}(\\rho={:0.2f}, \\tau={:0.2f})$'\n",
"             .format(probs, temperature))\n",
"\n",
"    fig, ax = plt.subplots(figsize=golden_size(8))\n",
"    ax.set_title(title)\n",
"    ax.plot(x.numpy(), density)\n",
"    ax.set(xlabel=r'$x$', ylabel=r'$p(x)$', ylim=(0., 5.))\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d3fbe87f562b42cba6a12f0060332ee5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.5, description='probs', max=0.99, min=0.05, step=0.05), FloatSlider(…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"@interact(probs=(0.05, 0.99, 0.05), temperature=(0.01, 3.0, 0.1))\n",
"def plot_relaxed_bernoulli_samples(probs, temperature):\n",
"    \"\"\"Histogram of `num_samples` draws from RelaxedBernoulli(probs, temperature).\"\"\"\n",
"    relaxed = tfp.distributions.RelaxedBernoulli(probs=probs, temperature=temperature)\n",
"    draws = relaxed.sample(num_samples).numpy()\n",
"    title = (r'$p(x) = \\mathrm{{BinConcrete}}(\\rho={:0.2f}, \\tau={:0.2f})$'\n",
"             .format(probs, temperature))\n",
"\n",
"    fig, ax = plt.subplots(figsize=golden_size(8))\n",
"    ax.set_title(title)\n",
"    sns.distplot(draws, kde=False, hist=True, ax=ax)\n",
"    ax.set(ylim=(0, num_samples), ylabel='nbr. of samples', xlabel=r'$x$')\n",
"    plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment