@piyush01123
Last active November 18, 2019 08:00
Gaussian Mixture Model
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import scipy.stats as stats\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulating Points \n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x1188a4f50>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_points = 100\n",
"\n",
"actual_means = [[2, 4], [-5, 3], [0, -4]]\n",
"actual_cov_mats = [[[1, 0], [0, 1]], [[1, 0], [0, 1]], [[1, 0], [0, 1]]]\n",
"\n",
"blobs = [np.random.multivariate_normal(m, s, (n_points,)) for m, s in zip(actual_means, actual_cov_mats)]\n",
"X = np.concatenate(blobs)\n",
"\n",
"plt.scatter(X[:,0], X[:, 1], color='b', s=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gaussian Mixture Model \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In a Gaussian Mixture Model (GMM), we assume that the data distribution is a weighted sum of K Gaussians: \n",
"\n",
"\\begin{equation}\n",
"P(X) = \\sum_{k=1}^K \\pi_k P(X | \\mu_k, \\sigma_k)\n",
"\\end{equation}\n",
"\n",
"where the weights $\\pi_k$ are non-negative and sum to 1, and $P(X | \\mu, \\sigma)$ is the normal probability density given by\n",
"\\begin{equation}\n",
"P(X | \\mu, \\sigma) = \\frac {1}{\\sqrt{2\\pi} \\sigma} e^{\\frac{-{(X-\\mu)}^2}{2\\sigma^2}}\n",
"\\end{equation}\n",
"\n",
"which in $d > 1$ dimensions, with mean vector $\\mu$ and covariance matrix $\\Sigma$, becomes \n",
"\\begin{equation}\n",
"P(X | \\mu, \\Sigma) = \\frac {1}{{(2\\pi)}^{\\frac d 2} {|\\Sigma|}^{\\frac 1 2}} e^{\\frac{-{(X-\\mu)}^T \\Sigma^{-1} (X-\\mu)}{2}}\n",
"\\end{equation}\n",
"\n"
]
},
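{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the multivariate density, here is a minimal sketch (not part of the original notebook) that evaluates the standard d-dimensional Gaussian density by hand, with normalizer sqrt((2 pi)^d |Sigma|), and compares it with `scipy.stats.multivariate_normal`. The helper name `gaussian_pdf` and the test point are assumptions made for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"def gaussian_pdf(x, mu, cov):\n",
"    # Density of a d-dimensional Gaussian evaluated at a single point x.\n",
"    x = np.asarray(x, dtype=float)\n",
"    mu = np.asarray(mu, dtype=float)\n",
"    cov = np.asarray(cov, dtype=float)\n",
"    d = mu.shape[0]\n",
"    diff = x - mu\n",
"    # Solve cov @ z = diff rather than forming the explicit inverse.\n",
"    maha = diff @ np.linalg.solve(cov, diff)\n",
"    norm_const = np.sqrt((2 * np.pi) ** d * np.linalg.det(cov))\n",
"    return np.exp(-0.5 * maha) / norm_const\n",
"\n",
"# Hypothetical test point near the first simulated blob.\n",
"point = [1.0, 3.0]\n",
"mean = [2.0, 4.0]\n",
"cov = [[1.0, 0.0], [0.0, 1.0]]\n",
"by_hand = gaussian_pdf(point, mean, cov)\n",
"reference = stats.multivariate_normal(mean, cov).pdf(point)\n",
"print(by_hand, reference)\n",
"```\n",
"\n",
"The two printed values should agree up to floating-point error; if they do not, the normalizing constant is the first thing to check."
]
},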
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Our EM algorithm \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fix K, the number of Gaussians, to some positive integer (3 in our case).\n",
"\n",
"Step 1: Initialize means $\\mu_k$, covariance matrices $\\Sigma_k$ and weights $\\pi_k$ for k=1 to K.\n",
"\n",
"### E step:\n",
"\n",
"Step 2: Calculate $P(X_i|\\mu_k, \\Sigma_k)$ for each point $X_i$ and each Gaussian. This is our familiar multivariate normal p.d.f. in $d$ dimensions:\n",
"\\begin{equation}\n",
"P(X_i | \\mu_k, \\Sigma_k) = \\frac {1}{{(2\\pi)}^{\\frac d 2} {|\\Sigma_k|}^{\\frac 1 2}} e^{\\frac{-{(X_i-\\mu_k)}^T \\Sigma_k^{-1} (X_i-\\mu_k)}{2}}\n",
"\\end{equation}\n",
"\n",
"This gives a matrix of size NxK where N is the number of points and K is the assumed number of Gaussians. \n",
"\n",
"Step 3: Multiply column $k$ of this matrix by the weight $\\pi_k$, then normalize so that each row sums to 1. Each element $P_{ik}$ of the resulting matrix is the probability that the $i^{th}$ point belongs to the $k^{th}$ Gaussian.\n",
"\n",
"\n",
"### M step:\n",
"\n",
"Step 4: Re-calculate the parameters of the model, i.e. $\\mu_k$, $\\Sigma_k$ and $\\pi_k$ for k=1 to K. The means and covariance matrices are weighted averages with weights $P_{ik}$, and each $\\pi_k$ is the mean of column $k$ of the $P$ matrix calculated in the E step.\n",
"\n",
"\\begin{equation}\n",
"\\mu_k = \\frac{\\sum_{i=1}^N P_{ik}X_i}{\\sum_{i=1}^N P_{ik}}\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"\\Sigma_k = \\frac {\\sum_{i=1}^N P_{ik}(X_i-\\mu_k){(X_i-\\mu_k)}^T}{\\sum_{i=1}^N P_{ik}}\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"\\pi_k = \\frac {1}{N} {\\sum_{i=1}^N P_{ik}}\n",
"\\end{equation}\n",
"\n",
"\n",
"##### Repeat steps 2 to 4 until a convergence criterion is met or for a fixed number of iterations\n",
"In our case we simply repeat for 100 iterations."
]
},
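{
"cell_type": "markdown",
"metadata": {},
"source": [
"The E and M steps described above can be sketched in vectorized form. This is a sketch under assumptions, not the implementation used in the next cell: the function name `em_step`, the toy data `X_demo` and the starting values are all hypothetical. The densities are multiplied by the weights before the rows are normalized.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"def em_step(X, pis, mus, sigmas):\n",
"    # One EM iteration: returns updated weights, means, covariances\n",
"    # and the N x K matrix of soft assignments.\n",
"    n, k = X.shape[0], len(pis)\n",
"    # E step: weighted densities, one column per Gaussian.\n",
"    weighted = np.column_stack([\n",
"        pi * stats.multivariate_normal(mu, sigma).pdf(X)\n",
"        for pi, mu, sigma in zip(pis, mus, sigmas)\n",
"    ])\n",
"    resp = weighted / weighted.sum(axis=1, keepdims=True)\n",
"    # M step: weighted means, weighted covariances, column means.\n",
"    nk = resp.sum(axis=0)\n",
"    new_mus = (resp.T @ X) / nk[:, None]\n",
"    new_sigmas = []\n",
"    for j in range(k):\n",
"        diff = X - new_mus[j]\n",
"        new_sigmas.append((resp[:, j, None] * diff).T @ diff / nk[j])\n",
"    return nk / n, new_mus, np.array(new_sigmas), resp\n",
"\n",
"# Hypothetical toy data: two well-separated blobs of 50 points each.\n",
"rng = np.random.default_rng(0)\n",
"X_demo = np.concatenate([\n",
"    rng.multivariate_normal([0, 0], np.eye(2), 50),\n",
"    rng.multivariate_normal([5, 5], np.eye(2), 50),\n",
"])\n",
"pis = np.array([0.5, 0.5])\n",
"mus = np.array([[1.0, 1.0], [4.0, 4.0]])\n",
"sigmas = np.array([np.eye(2), np.eye(2)])\n",
"for _ in range(10):\n",
"    pis, mus, sigmas, resp = em_step(X_demo, pis, mus, sigmas)\n",
"print(mus)\n",
"```\n",
"\n",
"Passing the whole data matrix to `pdf` avoids a Python loop over points. If a point lies very far from every Gaussian, all of its weighted densities can underflow to zero, so a more robust version would work with log-densities."
]
},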
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"pi_s = [1/3]*3\n",
"mu_s = [[-5,-5], [0,-5], [0,5]]\n",
"sigma_s = [[[1,0],[0,1]], [[1,0],[0,1]], [[1,0],[0,1]]]\n",
"\n",
"history = []\n",
"history.append(mu_s.copy())\n",
"\n",
"\n",
"n_iterations = 100\n",
"for itr in range(n_iterations):\n",
"    print(\"Mu\", mu_s)\n",
"    print(\"SIGMA\", sigma_s)\n",
"    print(\"PI\", pi_s)\n",
"\n",
"    # E Step\n",
"    # In the E step we recalculate the probability of each point being associated with each gaussian\n",
"    # This is comparable to the K-Means step that assigns each point to a cluster, except that here the assignment is soft\n",
"    rand_vars = [stats.multivariate_normal(mu, sigma) for mu, sigma in zip(mu_s, sigma_s)]\n",
"    gamma_mat = np.array([[pi*rand_var.pdf([x]) for pi, rand_var in zip(pi_s, rand_vars)] for x in X])  # shape (N, K) = (300, 3)\n",
"    probs = gamma_mat/np.sum(gamma_mat, axis=1, keepdims=True)  # each row sums to 1\n",
"    # print(\"PROBS\\n\", probs)\n",
"\n",
"    # M step\n",
"    # In the M step we re-calculate the parameters, i.e. the mu_s, the sigma_s and the pi_s\n",
"    # This is comparable to re-calculating the centroids in the K-Means algorithm\n",
"    for i in range(len(mu_s)):\n",
"        weights = probs[:,i].reshape((-1,1))  # column of soft assignments for gaussian i\n",
"        mu_s[i] = np.sum(weights * X, axis=0) / np.sum(weights)\n",
"        sigma_s[i] = (X-mu_s[i]).T.dot((X-mu_s[i])*weights) / np.sum(weights)\n",
"        pi_s[i] = np.mean(probs[:,i])\n",
"\n",
"    history.append(mu_s.copy())\n"
]
},
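{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above runs for a fixed 100 iterations. A common alternative is to stop once the log-likelihood of the data changes by less than a tolerance between iterations. The sketch below shows one way to compute that quantity; the names `log_likelihood` and `has_converged` and the default `tol` are assumptions made for illustration, not part of the original notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"def log_likelihood(X, pis, mus, sigmas):\n",
"    # Sum over points of the log of the mixture density.\n",
"    mix = sum(pi * stats.multivariate_normal(mu, sigma).pdf(X)\n",
"              for pi, mu, sigma in zip(pis, mus, sigmas))\n",
"    return float(np.sum(np.log(mix)))\n",
"\n",
"def has_converged(prev_ll, curr_ll, tol=1e-6):\n",
"    # True when the log-likelihood moved by less than tol.\n",
"    return abs(curr_ll - prev_ll) < tol\n",
"\n",
"# Tiny check: a single point at the mean of one standard 2-D Gaussian.\n",
"X_demo = np.array([[0.0, 0.0]])\n",
"ll = log_likelihood(X_demo, [1.0], [[0.0, 0.0]], [np.eye(2)])\n",
"print(ll)\n",
"```\n",
"\n",
"Inside the EM loop, one would compute the log-likelihood after each M step and `break` as soon as `has_converged` returns True."
]
},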
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualization of the means of the gaussians with iterations"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"history = np.array(history)\n",
"for i in range(len(history)):\n",
"    fig = plt.figure(figsize=(10,6))\n",
"    ax = fig.add_subplot(111)\n",
"    ax.set_xlim(-10,10)\n",
"    ax.set_ylim(-10,10)\n",
"    ax.set_title(\"GMM visualization\")\n",
"    ax.scatter(history[i,:,0], history[i,:,1], color='b', s=50)\n",
"    ax.legend(\"\")\n",
"    fig.savefig(\"plot_{}.png\".format(i))\n",
"    plt.close(fig)  # release the figure so 100+ figures are not kept open"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"from PIL import Image\n",
"# Only the first 20 of the saved frames are used for the animation\n",
"pngs = [Image.open(\"plot_{}.png\".format(i)) for i in range(20)]\n",
"\n",
"gif = Image.new(\"RGBA\", (pngs[0].width, pngs[0].height), (255,255,255))\n",
"\n",
"gif.save(fp=\"comb.gif\", format='GIF', append_images=pngs,\n",
"         save_all=True, duration=1000, loop=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![test](https://s5.gifyu.com/images/comb.gif)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}