Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save twolodzko/6d30aed57276593556ce8d83e241bd88 to your computer and use it in GitHub Desktop.

Select an option

Save twolodzko/6d30aed57276593556ce8d83e241bd88 to your computer and use it in GitHub Desktop.
Trying the randomized MAP sampling (RMS) algorithm from the Uncertainty in Neural Networks: Approximately Bayesian Ensembling paper by Pearce et al. (2020) on a trivial example: estimating the mean of a normal distribution with known variance
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Trying the RMS alsorithm from the [*Uncertainty in Neural Networks: Approximately Bayesian Ensembling*](https://arxiv.org/abs/1810.05546v5) paper by Pearce et al (2020) on trivial example of estimating mean for normal distribution with known variance. In this case we know the exact solution for the problem, since it can bre calculated using [conjugacy](https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import scipy.stats as sp\n",
"from scipy.optimize import minimize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"\n",
"n = 15\n",
"\n",
"# true parameters\n",
"μ = 5\n",
"σ = 2.7\n",
"\n",
"x = sp.norm(μ, σ).rvs(n)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"prior_μ = 0\n",
"prior_σ = 10\n",
"prior_dist = sp.norm(prior_μ, prior_σ)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5.003623405947105, 0.6954491092861294)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def posterior(x, prior_μ, prior_σ):\n",
" n = len(x)\n",
" σ2 = σ ** 2\n",
" prior_σ2 = prior_σ ** 2\n",
" \n",
" # see: https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution\n",
" post_σ2 = 1 / (1/prior_σ2 + n/σ2) \n",
" post_μ = post_σ2 * (prior_μ/prior_σ2 + np.sum(x)/σ2)\n",
" \n",
" return float(post_μ), np.sqrt(post_σ2)\n",
"\n",
"post_μ, post_σ = posterior(x, prior_μ, prior_σ)\n",
"post_dist = sp.norm(post_μ, post_σ)\n",
"\n",
"post_μ, post_σ"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.9764284 , 0.69544911],\n",
" [4.95463788, 0.69544911],\n",
" [5.01882196, 0.69544911],\n",
" ...,\n",
" [5.00404823, 0.69544911],\n",
" [4.99212787, 0.69544911],\n",
" [5.0073286 , 0.69544911]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"R = 5000\n",
"results = []\n",
"\n",
"for _ in range(R):\n",
" acc_μ = prior_dist.rvs(1)\n",
" results.append(posterior(x, acc_μ, prior_σ))\n",
" \n",
"results = np.vstack(results)\n",
"results"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5.003886029404278, 0.048175803294772626)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(results[:, 0]), np.std(results[:, 0])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"grid = np.linspace(post_μ - 2*post_σ, post_μ + 2*post_σ, 1000)\n",
"\n",
"plt.hist(results[:, 0], density=True, label=\"RMS approximation\")\n",
"plt.plot(grid, post_dist.pdf(grid), label=\"True posterior\")\n",
"plt.title('Posterior for $\\mu$')\n",
"plt.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment