@viniciusmss
Last active September 29, 2018 12:50
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Normal-inverse-gamma in SciPy\n",
"\n",
"Even though SciPy does have classes defined for the normal distribution (`scipy.stats.norm`) and the inverse-gamma distribution (`scipy.stats.invgamma`), it does not have one defined for the normal-inverse-gamma distribution. To help you, the functions below implement the pdf and a sampler for the normal-inverse-gamma distribution."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy import stats"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"'''\n",
"Function definitions for the normal-inverse-gamma distribution. The parameters\n",
"of the distribution, namely mu, lambda / nu, alpha, beta, are as defined here:\n",
"\n",
" https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution\n",
"\n",
"Note that we use the symbol nu (ν) rather than lambda (λ) for the third parameter.\n",
"This is to match the notation used in the conjugate priors table on Wikipedia:\n",
"\n",
" https://en.wikipedia.org/wiki/Conjugate_prior#Table_of_conjugate_distributions\n",
"'''\n",
"\n",
"def norminvgamma_pdf(x, sigma2, mu, nu, alpha, beta):\n",
"    '''\n",
"    The probability density function of the normal-inverse-gamma distribution at\n",
"    x (mean) and sigma2 (variance).\n",
"    '''\n",
"    return (\n",
"        stats.norm.pdf(x, loc=mu, scale=np.sqrt(sigma2 / nu)) *\n",
"        stats.invgamma.pdf(sigma2, a=alpha, scale=beta))\n",
"\n",
"def norminvgamma_rvs(mu, nu, alpha, beta, size=1):\n",
"    '''\n",
"    Generate `size` samples from the normal-inverse-gamma distribution. This\n",
"    function returns a (size x 2) array where each row is one sample, (x, sigma2).\n",
"    '''\n",
"    sigma2 = stats.invgamma.rvs(a=alpha, scale=beta, size=size)  # Sample sigma^2 from the inverse-gamma\n",
"    x = stats.norm.rvs(loc=mu, scale=np.sqrt(sigma2 / nu), size=size)  # Sample x from the normal, given sigma^2\n",
"    return np.vstack((x, sigma2)).transpose()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pre-class work to prepare for class activities:\n",
"\n",
"In this session we model a data set using the normal likelihood with a normal-inverse-gamma prior. Reuse your code from the previous class so that you can calculate posterior hyperparameters from the prior hyperparameters and data. Have this code ready for a class activity.\n",
"\n",
"\n",
"In the previous session you were told which prior hyperparameters to use. This time you have to choose them yourself. Given the information below, find reasonable values for the prior hyperparameters of the normal-inverse-gamma distribution — μ₀, ν₀, α₀, β₀. You will be asked to provide your values for the prior hyperparameters in class, and to explain how you came up with them.\n",
"\n",
"\n",
"Frame the information below as an optimization problem. You should design a function that is minimized when the constraints below are satisfied.\n",
"\n",
"The data are normally distributed. The error margins given below represent 1 standard deviation from the mean of the parameter.\n",
"\n",
"- Constraint: the mean of the data is approximately 2.3 ± 0.5.\n",
"- Constraint: the variance of the data is approximately 2.75 ± 1.\n",
"\n",
"Find values of the hyperparameters μ₀, ν₀, α₀, β₀ for the normal-inverse-gamma prior that match this information.\n",
"\n",
"Paste the values that you arrived at for the hyperparameters of the normal-inverse-gamma prior in a Google Doc and explain how you determined that those values are reasonable given the information you were provided with. Be ready to paste a link to your document into a class poll."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Started at [1. 1. 2.1 1. ]\n",
"Ended at [ 2.30000022 11.00000092 9.56250409 23.54689 ]\n",
"f([ 2.30000022 11.00000092 9.56250409 23.54689 ]) = 0.000000\n"
]
}
],
"source": [
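"'''\n",
"Moments of the normal-inverse-gamma distribution, from Wikipedia:\n",
"https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution\n",
"(Wikipedia's lambda is nu here.)\n",
"\n",
"E[x]         = mu\n",
"E[sigma^2]   = beta / (alpha - 1)                      (for alpha > 1)\n",
"Var[x]       = beta / ((alpha - 1) * nu)               (for alpha > 1)\n",
"Var[sigma^2] = beta^2 / ((alpha - 1)^2 * (alpha - 2))  (for alpha > 2)\n",
"'''\n",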
"\n",
"\n",
"from scipy.optimize import minimize\n",
"import numpy as np\n",
"\n",
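"def f(x):\n",
"    '''Squared-error loss that is zero when all four prior moments hit their targets.'''\n",
"    mu, nu, alpha, beta = x\n",
"    E_X = mu                                                # prior mean of x\n",
"    E_Sigma = beta / (alpha - 1)                            # prior mean of sigma^2\n",
"    Var_X = beta / ((alpha - 1) * nu)                       # prior variance of x\n",
"    Var_Sigma = beta ** 2 / ((alpha - 1)**2 * (alpha - 2))  # prior variance of sigma^2\n",
"    # Targets: 2.3 ± 0.5 gives E[x] = 2.3 and Var[x] = 0.5^2 = 0.25;\n",
"    # 2.75 ± 1 gives E[sigma^2] = 2.75 and Var[sigma^2] = 1^2 = 1.\n",
"    return (E_X - 2.3)**2 + (E_Sigma - 2.75)**2 + (Var_X - 0.25)**2 + (Var_Sigma - 1)**2\n",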
"\n",
"x_initial = np.array([1, 1, 2.1, 1])  # start with alpha > 2 so every moment in f is finite\n",
"result = minimize(f, x_initial)\n",
"x_final = result.x\n",
"print('Started at', x_initial)\n",
"print('Ended at', x_final)\n",
"print('f(%s) = %.6f' % (x_final, f(x_final)))"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2.29898905 2.74861713]\n",
"[0.50006818 0.99788304]\n"
]
}
],
"source": [
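"mu_0, nu_0, alpha_0, beta_0 = x_final\n",
"# Monte Carlo check: the sample means of (x, sigma2) should be close to (2.3, 2.75)\n",
"# and the sample standard deviations close to (0.5, 1).\n",
"samples_1e6 = norminvgamma_rvs(mu=mu_0, nu=nu_0, alpha=alpha_0, beta=beta_0, size=1000000)\n",
"print(np.mean(samples_1e6, axis=0))\n",
"print(np.std(samples_1e6, axis=0))"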
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
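The `norminvgamma_pdf` function in the notebook multiplies a normal density for x, with variance sigma2 / nu, by an inverse-gamma density for sigma2. As a sketch for checking that parameterization by hand, the block below writes the same log-density out explicitly using only the `math` module, relying on the fact that `scipy.stats.invgamma` with `a=alpha, scale=beta` has density beta^alpha / Gamma(alpha) * s^(-alpha-1) * exp(-beta / s). The names `norminvgamma_logpdf` and `total_mass` are hypothetical helpers, and the integration box and grid size are arbitrary choices, assumed wide enough for hyperparameters near the ones the optimizer found.

```python
import math


def norminvgamma_logpdf(x, sigma2, mu, nu, alpha, beta):
    """Hypothetical helper: log-density of NIG(mu, nu, alpha, beta) at (x, sigma2).

    Assumes the same parameterization as norminvgamma_pdf in the notebook:
    x | sigma2 ~ Normal(mu, sigma2 / nu) and sigma2 ~ InvGamma(alpha, scale=beta).
    """
    if sigma2 <= 0:
        return -math.inf  # the inverse-gamma has no mass at sigma2 <= 0
    # log Normal(x | mu, sigma2 / nu)
    var_x = sigma2 / nu
    log_normal = -0.5 * math.log(2 * math.pi * var_x) - (x - mu) ** 2 / (2 * var_x)
    # log InvGamma(sigma2 | alpha, beta)
    log_invgamma = (alpha * math.log(beta) - math.lgamma(alpha)
                    - (alpha + 1) * math.log(sigma2) - beta / sigma2)
    return log_normal + log_invgamma


def total_mass(mu, nu, alpha, beta, x_half_width=8.0, s_max=15.0, n=400):
    """Hypothetical helper: midpoint-rule integral of the density over
    [mu - x_half_width, mu + x_half_width] x (0, s_max]; should be close to 1."""
    dx = 2 * x_half_width / n
    ds = s_max / n
    total = 0.0
    for i in range(n):
        s = (i + 0.5) * ds  # midpoint in the sigma2 direction
        for j in range(n):
            x = mu - x_half_width + (j + 0.5) * dx  # midpoint in the x direction
            total += math.exp(norminvgamma_logpdf(x, s, mu, nu, alpha, beta))
    return total * dx * ds


# Hyperparameters close to the optimizer's result in the notebook.
print(total_mass(2.3, 11.0, 9.5625, 23.546875))  # a value close to 1.0
```

Working in log space keeps the inverse-gamma factor from overflowing for small sigma2; `math.exp` simply returns 0.0 where the density is negligible.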
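Because each of the four targets pins down one moment, the moment-matching problem posed to `minimize` can also be solved in closed form, which gives an independent check on the optimizer. Dividing Var[sigma^2] by E[sigma^2]^2 leaves 1 / (alpha - 2), so alpha follows directly; beta then comes from E[sigma^2] = beta / (alpha - 1), nu from Var[x] = E[sigma^2] / nu, and mu from E[x] = mu. The helper name `nig_hyperparameters_from_moments` is hypothetical, and the sketch assumes, as the objective `f` does, that each ± margin is one standard deviation of the corresponding prior quantity.

```python
def nig_hyperparameters_from_moments(mean_x, var_x, mean_sigma2, var_sigma2):
    """Hypothetical helper: invert the normal-inverse-gamma moment formulas.

    Requires var_sigma2 > 0 so that alpha > 2 and every moment is finite.
    """
    alpha = 2 + mean_sigma2 ** 2 / var_sigma2  # Var[sigma^2] = E[sigma^2]^2 / (alpha - 2)
    beta = mean_sigma2 * (alpha - 1)           # E[sigma^2] = beta / (alpha - 1)
    nu = mean_sigma2 / var_x                   # Var[x] = beta / ((alpha - 1) nu) = E[sigma^2] / nu
    mu = mean_x                                # E[x] = mu
    return mu, nu, alpha, beta


# Mean 2.3 ± 0.5 and variance 2.75 ± 1, each ± read as one standard deviation.
mu_0, nu_0, alpha_0, beta_0 = nig_hyperparameters_from_moments(2.3, 0.5 ** 2, 2.75, 1.0 ** 2)
print(mu_0, nu_0, alpha_0, beta_0)  # prints: 2.3 11.0 9.5625 23.546875
```

These values agree closely with the optimizer's result, `[2.30000022, 11.00000092, 9.56250409, 23.54689]`, which is consistent with the loss reaching 0 there: the four constraints have an exact solution.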
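The pre-class text asks for code that turns prior hyperparameters and data into posterior hyperparameters. The author's code from the previous class is not included here, so the following is only a minimal sketch of the standard normal-inverse-gamma conjugate update as listed in the Wikipedia conjugate-prior table that the notebook cites, with n observations and sample mean x̄: mu_n = (nu_0 mu_0 + n x̄) / (nu_0 + n), nu_n = nu_0 + n, alpha_n = alpha_0 + n / 2, and beta_n = beta_0 + ½ Σ(x_i - x̄)² + n nu_0 (x̄ - mu_0)² / (2 (nu_0 + n)). The function name `nig_posterior` and the three-point data set are hypothetical.

```python
def nig_posterior(data, mu_0, nu_0, alpha_0, beta_0):
    """Hypothetical helper: conjugate update for a normal likelihood with
    unknown mean and variance under a normal-inverse-gamma prior."""
    n = len(data)
    x_bar = sum(data) / n                           # sample mean
    ss = sum((x - x_bar) ** 2 for x in data)        # sum of squared deviations
    mu_n = (nu_0 * mu_0 + n * x_bar) / (nu_0 + n)   # precision-weighted mean
    nu_n = nu_0 + n
    alpha_n = alpha_0 + n / 2
    beta_n = (beta_0 + 0.5 * ss
              + (n * nu_0 / (nu_0 + n)) * (x_bar - mu_0) ** 2 / 2)
    return mu_n, nu_n, alpha_n, beta_n


# Hypothetical data combined with the prior that matches the constraints above.
print(nig_posterior([2.0, 3.0, 4.0], 2.3, 11.0, 9.5625, 23.546875))
```

With nu_0 = 11, the prior mean counts roughly as much as 11 observations, so three data points move mu only part of the way from 2.3 toward their sample mean of 3.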