Skip to content

Instantly share code, notes, and snippets.

@piyush01123
Last active November 18, 2019 08:00
Show Gist options
  • Save piyush01123/81d7e05b5b32e41fb2e83c2d89e68722 to your computer and use it in GitHub Desktop.
Save piyush01123/81d7e05b5b32e41fb2e83c2d89e68722 to your computer and use it in GitHub Desktop.
Gaussian Mixture Model
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import scipy.stats as stats\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulating Points \n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x1188a4f50>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD8CAYAAABq6S8VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAX9ElEQVR4nO3dX4wkV3XH8d/BXkT4EzvsrmJie2OjxJEWg0V2zHqDSACjzEBsNlrlARAIwsMmKBBIiCzAUkbz4iBAJEiQMKPgiAgnhJglyIh4bSOEFM2w0GsMsbFBFjHGDsiDIwsEEY7FyUN1a3vb/aeq7q2qe6u/H2k0Mz3Tt251VZ06fe6tanN3AQDy9ZSuOwAACEMgB4DMEcgBIHMEcgDIHIEcADJHIAeAzEUJ5Gb2Z2Z2j5ndbWb/bGZPi9EuAGCx4EBuZhdK+lNJK+5+uaRzJL0mtF0AQDmxSivnSvoFMztX0tMl/XekdgEAC5wb2oC7P2xmH5D0oKT/lXSbu9827zn79u3zSy65JHTRALBUTp8+/UN33z/5eHAgN7NfknRU0qWSHpP0r2b2enf/xMT/HZd0XJIOHDigwWAQumgAWCpm9t1pj8corbxC0n+5+667/5+kE5J+a/Kf3H3L3VfcfWX//iedUAAANcUI5A9KusrMnm5mJulqSfdGaBcAUEJwIHf3U5JulnSnpP8ctrkV2i4AoJzgGrkkufu6pPUYbQEAquHKTgDIHIEcADJHIAeAzBHIAXRqZ0daWyu+o54og50AUNfGhnTyZPHzrbd225dcEcgBdGp9/ezvqI5ADqBTR46QiYeiRg4AmSOQA0DmCOQAkDkCOQBkjkAOAJkjkANA5gjkAJA5AjkAZI5ADgCZI5ADQOYI5ACQOQI5AAzlektdbpoFAEO53lI3SkZuZueb2c1mdp+Z3WtmR2K0CwAhqmbY6+vS6mp+t9SNlZF/SNKt7v4HZvZUSU+P1C4A1FY1w871lrrBGbmZnSfptyV9TJLc/XF3fyy0XQCYp0y2nWuGXVWMjPxSSbuS/sHMrpB0WtLb3f0nEdoGgKnKZNu5ZthVxaiRnyvpNyX9nbu/UNJPJL1r8p/M7LiZDcxssLu7G2GxAJbFtOx7WbLtMmIE8ockPeTup4a/36wisJ/F3bfcfcXdV/bv3x9hsQBytrUl7dtXfJfml0pG2ffGxpnHRtn2EaZWhJdW3P0HZvY9M/sNd/+WpKslfTO8awD67D3vkR59tPh+/Pj8Ugkf0DxfrFkrb5N003DGynck/WGkdgH01A03FEH8hhuK3+cF62Wpdddl7t76QldWVnwwGLS+XABo0s5O8c5ifb2Zko+ZnXb3lcnHubITACLp6spQ7rUCLLlc7y/SlJDXo6uZNGTkwJLL9f4iTQl5Pbqq5fcmIyerAOppM4vM4TjNcX56bwY719aKs+jqKlkFkCqO0zC9H+xknimQPo7TZvSmtLJMV3nl8PYUmKat4zSFY6TNPvQmkC+TaZcrpy6FAwvx7exIV11VfKW0bVM4RtrsQ29KK11oevL/LDm+PWVmRD9tbEinTp35OZVtm8Ix0mof3L31r0OHDnkfrK66S8V3zLe9XbxO29td9yQNOb8e433f3nY/fLj4ynFdciNp4FNiKoE8QFMHY84HOcrJOQlY1Pe+7r8prBeBPCOHDxdb5vDhrnsyXwo7dq5ivnZtb4dFyytzkspx30nh5Esgr6DrnSyXQJ7Cjo3w7RB7fy/TXo77TtdxwZ1AXknXO1kKO0wZufSz70K3Qxf7O/tOPQTyClLbyXJ+G470pbZPpNSfun1pah0I5Jna3nbfuzdextT1uw1gkbr76GTwrBpMp/1/3b40dZzNCuTMI0/cxkbxcVh798aZj5rC/FrkYfw6Cam9aybq7qOjaxUee0w6//zi+2iOe5n57dOudTh2TBoMiu9VtH6cTYvuTX81mZGn9LbMPbw/055PqQVtGM8qRz/v3dv9vjJrnx09Pj5ZoMoc98n58eNtpfIOVstSWkmtdNBEf2K2mdrr1bRlP3FVWf/x/93cdN+zp9q+0tRrPb7PLkp0Qksjhw+ntb8sTSBfdLZue4M0sVwy8vr6dOKqs+1CA1uVjLzKskbrsrm5eJ2qBOrUBitDLU0gn6VPBzDqS/UAraPOPh07sM1rr8qyxk8UKWT9qWo8kEs6R9LXJH1u0f92EciXbYOj/1LYp2MlSFUy8mXWRiD/c0n/lGog70JOJZA67acQSNCtFPeB1MuZIRoN5JIukvQFSS8nkJ+R06BknfZTuzQcT5ZiAhDa7qJlpjbBIOZr1HQgv1nSIUkvnRXIJR2XNJA0OHDgQPgaZSCljHzR87vIyBm3iCvmBS1lTbYfa7rsrH6XuUAuVt1+3vPq1P9jbIPGArmkayT97fDnmYF8/CvFjLxPs1qmSTFoTpu3S3Ze37RtHOt1LTvYOW12S51+zfp7ndkz054fehzUmZGTdEYu6a8kPSTpAUk/kPRTSZ+Y95wUA3lXgS72YFHMjLtNKZ5oclNmG9fdD8pun2kZ8+Zm8djmZvX2Jvvc9DvTttupqpXph2TkzS93UaaSayBM/UTTF3X3kyrbZ1aWXvedQu77dkwE8qFUAkbZfpQ5KKq016QU+oD5uthGdevLTEl8sqW/IGgklbN72X5MfshEysEyldcWzQsZ1Cy7n8z7v1iDqrlZikDeZI0wtrL9SOHTguq+e0C3mtwe04JslTp66P4UsvycLUUg72JDNh286kylitGn8TaW4QDpoya3W5sZcdlldXUstpnALEUg7yIj7DLIzVp2jD6Nt0Gmnac2t1vb2X8b61Z2fKrNGLAUgbxpTWQBdZ6/aBAoRomJ4I0q2s7+2yg5Tq4TGXlPAnnsGmCVNkOfU6cNgnk6Ut8WdUqAIdoI5DH7HastAnkEZTPd0QURZa4+q5M9x66Bz0JtPB25bItp+1UTfU/9xDZp9BqEfsISgbwlow1W9dNUyrTZ9kGc28HSZ13UhOvoqp6duvEEL+Q4JpC3YHv7zGcExryIoeqBwIGDOmIkDDlNAW5bjPUmkLegyoHQZP1tWj82N93PO8/94EEOsq6k/tq21b9cykQpIpC3oKlBzkXtlhldH72tW7RMDrLm5P7aVp3plMIsjzpS7h+BPIKuRrHHR+jrToHc3HR/xjOKr+uum32AjUpDKe7EuUs5QJQx60Q0Wq/Rfjr6e64nrpT7TSCPYNEGbqqWPR7Iy85tndf/WYMuKe/A6N6sfW2034z2z5DZVSmc7FLowyy9CORdv8BVSxyLlB3hH39sMmuuU6IJuZAImBRzvyGZmK8XgTyVjVy19lfl8TJZ/3hGneJFC0BdbV0jkateBPJUNlBI5l11wGhWW6EXFizqJxBbk8fveNt93o97Echjq7tjhWTes+qJTfW5yv+ncqJEPzUZYCeTpa4H7Zs6lgjkU8Tescq0N22Ev6mNPlmGAbrUVkbu3n1W3tTyCeRTxN6x6ma/TW30yTJMGxk3WX36lmEbdb2OZOQdaDNTaPJ/Fj2njSyl60wIi7GN8kUgn6Ot2l3X7ZCRw51t1KbYr/WsQH6uApnZxZL+UdIvS3JJW+7+odB227S+fvb3FNuO0c6RI9Ktt4b1I4VlIAzbqD0bG9LJk8XPTb7mVgT5gAbMniPpOe5+p5k9S9JpSb/v7t+c9ZyVlRUfDAZBy+2DnZ1iQ6+vFwcXgH6JfYyb2Wl3X5l8PDgjd/fvS/r+8Ocfm9m9ki6UNDOQo9jA114rPfpo8TsZEtA/bb37eUrMxszsEkkvlHRqyt+Om9nAzAa7u7sxFxvdzo60tlZ8b8rGRhHE9+5tpqQDYHlEC+Rm9kxJn5b0Dnf/0eTf3X3L3VfcfWX//v2xFtuIUV1rY6O5ZayvS6ur0i23NFtWaeOkBKBbwaUVSTKzPSqC+E3ufiJGm20br2U1Ofg50tZbrrYGWwB0J8asFZP0MUn3uvsHw7vUjcmA15eg18ZJCUC3YpRWXizpDZJebmZ3Db9eFaHdVo1KHX0LeKPMn1kxWBbLWE6MMWvlPyRZhL50Knapg6mFQDeWsZwYddbKMll01m9jwDS2Zcxk0D99fXc9T5TBzmW06KyfY216GTMZ9M8yXrlKIK9pUaDucmeqW9bJ8eQDIMIl+nVwiX6z1taKzHp1dfkyE6DPZl2iT408USH16mWsEQLLjEDegBiDhiGDpUw5BJYLgbwBMWaspJ5VM8MFSAeDnQ3I5d7hIZjhAqSDQN6A1INwDMxwAdJBIEcty3CyAnJBjTwzTdamqXsDeSIjz0yTtWnq3kCeCOSZyeGDogG0iys7ASATXNkJAD1FIAeAzBHIASBzBHIAyByBHAAyRyAHgMxFCeRmtmZm3zKz+83sXTHaBACUExzIzewcSR+R9EpJByW91swOhrYLACgnRkb+Ikn3u/t33P1xSZ+UdDRCuwCAEmIE8gslfW/s94eGj53FzI6b2cDMBru7uxEWCwCQWhzsdPctd19x95X9+/e3tVgswB0PgfzFuGnWw5IuHvv9ouFjyAB3PATyFyOQf1XSr5vZpSoC+GskvS5Cu2gBdzwE8hccyN39CTN7q6STks6RdKO73xPcM7SCT/oB8hflfuTu/nlJn4/RFgCgGq7sBIDMEcgBIHMEcgDIHIEcADJHIAeAzBHIASBzBHIAyByBHAAyRyAHgMwRyAEgcwRyAMgcgRxoCPd6R1ui3DQLwJNxr3e0hUAONIR7vaMtBHKgIdzrHW2hRg4AmSOQA0DmCOQAkDkCOQBkjkAOAJkLCuRm9n4zu8/MvmFmnzGz82N1DABQTmhGfruky939BZK+Lend4V0CEANXli6PoEDu7re5+xPDX78s6aLwLgH9tSi4xgy+oytLNzaaaR/piHlB0Jsl/UvE9oDeWXTZfszL+qddWcptA/ppYSA3szskXTDlT9e7+2eH/3O9pCck3TSnneOSjkvSgQMHanUWyN2xY9JgUHyfJuZl/dOuLOW2Af1k7h7WgNmbJP2RpKvd/adlnrOysuKDwSBouUDXdnaKDHd9vQiaZaytFRnx6ioZMaozs9PuvjL5eOislTVJ10l6ddkgDvTFtBr0IuvrRRCfzIipXSNE6KyVD0t6lqTbzewuM/tohD4BWZgVlOcZlTsmM/iqJ4WmA/+09jnZpCtosNPdfy1WR4DcxLy7YdXadd1By7LloGntM1CaLm5jC7RoViCtelKoO2hZNhhPa5+B0nQFD3bWwWAnllUXg53jJw+p+gAt0tHIYCfQhj7VZuvU1UON199n1ejL6tO26BNKK0hen2qzXXxqUMySSJ+2RZ8QyJG8ZajN1pmTXlaXg7JoBzVyIAGTtfMmAzvyRY0cSMx4vXmydj45r5zaNOYhkAMljAfSWEF13iDkrMB+7bXS1lb7FwMhce7e+tehQ4ccyMnqqrtUfB//eXv7zPeqqjx3e9t9795iuaPvq6vVl1nG+PohLZIGPiWmMtgJlDDrApmQWRxVBiGPHJFuuaVY3rFj0okTzQ04MqCZHwY7gQBND0qGtM+Aaf8w2Ak0IPQCm0Xq3GGx7nPbrI1Th4+L0gqQsDpljlEmPvrwiqZvxFUHFxbFRUYOtGBrS9q3r/gulc9I62T8oyB54sT0585adpu3D+jiVgW9Nm0EtOkvZq1g2YzPOHFvdmbIotkwTc9K2d52P3y4+KozmydlIbOUYtCMWSsEcmBMjAN1Whubm0UQ39yMt5y6/Smz7JD+jU4UfZzC2PXUTAI5UEKMA7Xrg308CG9uuu/ZU70/IetARt4cAjlQQlMZedW2Y2TEq6tnSjp79ixua3yZXQcsTEcgBxpWtzY97XmhGfF4Rj5e0pmn63cSWIxAjqXRVTa5KBDO6te05zWxDovaJAtPX6OBXNI7JbmkfWX+n0COJnWVWdYNhFWfV3c5ZNz5ayyQS7pY0klJ3yWQIwWxMsumZ3fUVTcgk3Hnr8lAfrOkKyQ9QCBHn5QJmLGz3Jgnjz4E7hgDx33SSCCXdFTSh4Y/E8jRK11k5DFPDDmWUiZfz1nrkOO6xVA7kEu6Q9LdU76OSjol6TwvEcglHZc0kDQ4cOBAu2sPZKLOiaGrrLWJ9icDNBn52aJn5JKeL+mRYQB/QNITkh6UdMGi55KRIyepBI0qs17aWHaV5S5TOahJjU8/pLSCPog9p7uNfrQR/EKnSC5rKSS2WYGc29gCY6bdXrWLT8yp0o8qnzRU12iZx44Vd04cfVhF2eXyqUMNmxbdm/4iI0eqmrxEv+02mkBm3S2RkQOLxchuY3xoQhtZdh1k1mnigyWAyKZ9aEJfPtqs6Y+2Qz0EciCyacEu5LM3R7o4GZRdZl9OVLkikAM1VA1cMT7aLMbJYNKi9Si7zMn/I7C3bFrhvOkvBjuRuy4G/arcknae8YHUundsXPR/sW7Di7OJ29gC8eR0s6x57TS1HrE+GCN227kjkAOZa/Oujk2b14e6H9CxDAjkQEUpBLx5Uu/fPJRe6pkVyBnsBGZoYnAxpjr9izUIGdpOyOAvUyCnmBbdm/4iI0cOUsz8Qj8guezdBau2g3aI0gpQXWrBfF4ALVN33tyMM7sktddlWcwK5FyiD8wR43L7mOZdIj+vr7P+VveS+1RvIbCsqJEDc8S4kGeaeTXmeX+bVx+e19em1gNpsCJbb9fKyooPBoPWlwukYm2tyJBXV5+c2c77W5v9QHrM7LS7r0w+TmkF6MC8kkabdxjkbob9QEYOAJmYlZFTI0evLNPNmnJc1xz7nANKK+iFnZ1iZsZjj0mnThWP9b3mm9qMmjJy7HMOCOTohVGAOHx4eWZn5FjfzrHPOSCQoxfGA8SyXLqd41zuHPucAwI5eoEAgWUWPNhpZm8zs/vM7B4ze1+MTgEAygsK5Gb2MklHJV3h7s+T9IEovQLQO8xYaU5oaeUtkt7r7j+TJHd/JLxLAPqIGSvNCS2tXCbpJWZ2ysy+ZGZXzvpHMztuZgMzG+zu7gYuFkBuuN9LcxZm5GZ2h6QLpvzp+uHzny3pKklXSvqUmT3Xp1wu6u5bkrak4srOkE4DyA8D0s1ZmJG7+yvc/fIpX5+V9JCkE8Nb5X5F0s8l7Wu60wCoOeOM0Br5v0l6maQvmtllkp4q6YfBvQKwEDVnjIQG8hsl3Whmd0t6XNIbp5VVAMTHVZIYCQrk7v64pNdH6guACqg5Y4S7HwJA5gjkAJA5AjkAZI5ADgCZI5ADQOYI5ACQOQI5AGTOurh+x8x2JX239QXHt0/9vJKV9cpLH9erj+skha/Xr7r7/skHOwnkfWFmA3df6bofsbFeeenjevVxnaTm1ovSCgBkjkAOAJkjkIfZ6roDDWG98tLH9erjOkkNrRc1cgDIHBk5AGSOQB6Bmb3NzO4zs3vM7H1d9ycmM3unmbmZZf/JT2b2/uF2+oaZfcbMzu+6TyHMbM3MvmVm95vZu7ruTwxmdrGZfdHMvjk8nt7edZ9iMrNzzOxrZva5mO0SyAOZ2cskHZV0hbs/T9IHOu5SNGZ2saTflfRg132J5HZJl7v7CyR9W9K7O+5PbWZ2jqSPSHqlpIOSXmtmB7vtVRRPSHqnux9U8VnAf9KT9Rp5u6R7YzdKIA/3FknvdfefSZK7P9Jxf2L6a0nXSerFQIq73+buTwx//bKki7rsT6AXSbrf3b8z/ICXT6pIKLLm7t939zuHP/9YRdC7sNtexWFmF0n6PUl/H7ttAnm4yyS9xMxOmdmXzOzKrjsUg5kdlfSwu3+967405M2S/r3rTgS4UNL3xn5/SD0JeCNmdomkF0o61W1PovkbFYnRz2M3HPqZnUvBzO6QdMGUP12v4jV8toq3gVdK+pSZPTeHzy5dsF7vUVFWycq8dXL3zw7/53oVb+FvarNvKM/Mninp05Le4e4/6ro/oczsGkmPuPtpM3tp7PYJ5CW4+ytm/c3M3iLpxDBwf8XMfq7ifgq7bfWvrlnrZWbPl3SppK+bmVSUIO40sxe5+w9a7GJl87aVJJnZmyRdI+nqHE62czws6eKx3y8aPpY9M9ujIojf5O4nuu5PJC+W9Goze5Wkp0n6RTP7hLtH+cxj5pEHMrM/lvQr7v6XZnaZpC9IOpB5kDiLmT0gacXds76JkZmtSfqgpN9x9+RPtPOY2bkqBmyvVhHAvyrpde5+T6cdC2RF5vBxSf/j7u/ouj9NGGbkf+Hu18Rqkxp5uBslPdfM7lYx4PTGPgXxnvmwpGdJut3M7jKzj3bdobqGg7ZvlXRSxYDgp3IP4kMvlvQGSS8fbqO7hlks5iAjB4DMkZEDQOYI5ACQOQI5AGSOQA4AmSOQA0DmCOQAkDkCOQBkjkAOAJn7f3I1cWhLLP7HAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_points = 100\n",
"\n",
"actual_means = [[2, 4], [-5, 3], [0, -4]]\n",
"actual_cov_mats = [[[1, 0], [0, 1]], [[1, 0], [0, 1]], [[1, 0], [0, 1]]]\n",
"\n",
"blobs = [np.random.multivariate_normal(m, s, (n_points,)) for m, s in zip(actual_means, actual_cov_mats)]\n",
"X = np.concatenate(blobs)\n",
"\n",
"plt.scatter(X[:,0], X[:, 1], color='b', s=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gaussian Mixture Model \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In Gaussian Mixture Model or GMM, we assume that our data distribution is given by K gaussians and say that \n",
"\n",
"\\begin{equation}\n",
"P(X) = \\sum_{k=1}^K \\pi_k P(X | \\mu_k, \\sigma_k)\n",
"\\end{equation}\n",
"\n",
"where $P(X | \\mu, \\sigma)$ is the normal probability distribution given by\n",
"\\begin{equation}\n",
"P(X | \\mu, \\sigma) = \\frac {1}{\\sqrt{2\\pi} \\sigma} e^{\\frac{-{(X-\\mu)}^2}{2\\sigma^2}}\n",
"\\end{equation}\n",
"\n",
"which in higher than 1-D translates to \n",
"\\begin{equation}\n",
"P(X | \\mu, \\Sigma) = \\frac {1}{\\sqrt{2\\pi} \\Sigma^{\\frac 1 2}} e^{\\frac{-{(X-\\mu)}^T \\Sigma^{-1} (X-\\mu)}{2}}\n",
"\\end{equation}\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Our EM algorithm \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assume K as some positive integer (we can assume 3 in our case).\n",
"\n",
"Step 1: Initialize means $\\mu_k$, covariance matrices $\\Sigma_k$ and weights $\\pi_k$ for k=1 to K.\n",
"\n",
"### E step:\n",
"\n",
"Step 2: Calculate $P(X_i|\\mu_k, \\Sigma_k)$ for each point $X_i$ for each gaussian. This is our familiar multivariate normal distribution p.d.f.\n",
"\\begin{equation}\n",
"P(X_i | \\mu_k, \\Sigma_k) = \\frac {1}{\\sqrt{2\\pi} \\Sigma_k^{\\frac 1 2}} e^{\\frac{-{(X_i-\\mu_k)}^T \\Sigma_k^{-1} (X_i-\\mu_k)}{2}}\n",
"\\end{equation}\n",
"\n",
"This is a matrix of size NxK where N is the number of points and K is the assumed number of gaussians. \n",
"\n",
"Step 3: Normalize this matrix so that each row sums up to 1. Each element $P_{ik}$ in this matrix represents probability of $i^{th}$ point being associated to $k^{th}$ gaussian.\n",
"\n",
"\n",
"### M step:\n",
"\n",
"Step 4: Re-calculate the parameters of the model ie $\\mu_k$, $\\Sigma_k$ and $\\pi_k$ for k=1 to K. We can do this as a calculation of weighted means, weighted covariance matrices and $\\pi_k$s are simply the normalized sum of each column in the $P$ matrix calculated in the E step.\n",
"\n",
"\\begin{equation}\n",
"\\mu_k = \\frac{\\sum_{i=1}^N P_{ik}X_i}{\\sum_{i=1}^N P_{ik}}\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"\\Sigma_k = \\frac {\\sum_{i=1}^N P_{ik}{(X_i-\\mu_k)}^T(X_i-\\mu_k)}{\\sum_{i=1}^N P_{ik}}\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"\\pi_k = \\frac {1}{N} {\\sum_{i=1}^N P_{ik}}\n",
"\\end{equation}\n",
"\n",
"\n",
"##### Repeat steps 2 to 4 till a convergence criteria is met or for a fixed number of iterations\n",
"In our case we simply repeat for 100 iterations."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"pi_s = [1/3]*3\n",
"mu_s = [[-5,-5], [0,-5], [0,5]]\n",
"sigma_s = [[[1,0],[0,1]], [[1,0],[0,1]], [[1,0],[0,1]]]\n",
"\n",
"history = []\n",
"history.append(mu_s.copy())\n",
"\n",
"\n",
"n_iterations = 100\n",
"for itr in range(n_iterations):\n",
" print(\"Mu\", mu_s)\n",
" print(\"SIGMA\", sigma_s)\n",
" print(\"PI\", pi_s)\n",
"\n",
" # E Step\n",
" # In E step we are recalculating the probabilities of each point being associated with each gaussian\n",
" # Notice that we can compare this step to the step in K-Means where we are assigning each point to a cluster except that here it is a soft assignment\n",
" rand_vars = [stats.multivariate_normal(mu, sigma) for mu, sigma in zip(mu_s, sigma_s)]\n",
" gamma_mat = np.array([[pi*rand_var.pdf([x]) for pi, rand_var in zip(pi_s, rand_vars)] for x in X] ) #3x100\n",
" probs = gamma_mat/np.sum(gamma_mat, axis=1).reshape((300,1))\n",
"# print(\"PROBS\\n\", probs)\n",
"\n",
" # M step\n",
" # In M step we are re-calculating the parameters ie the mu_s, the sigma_s and the pi_s\n",
" # We can compare this to re-calculating the centroids in the K-Means algorithm\n",
" for i, (mu, sigma) in enumerate(zip(mu_s, sigma_s)):\n",
" mu_s[i] = np.sum(probs[:,i].reshape((300,1)) * X, axis=0) / np.sum(probs[:,i])\n",
" sigma_s[i] = (X-mu_s[i]).T.dot((X-mu_s[i])*probs[:,i].reshape((300,1))) / np.sum(probs[:,i])\n",
" pi_s[i] = np.mean(probs[:,i])\n",
" \n",
" history.append(mu_s.copy())\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualization of the means of the gaussians with iterations"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"history = np.array(history)\n",
"for i in range(len(history)):\n",
" fig = plt.figure(figsize=(10,6))\n",
" ax = fig.add_subplot(111)\n",
" ax.set_xlim(-10,10)\n",
" ax.set_ylim(-10,10)\n",
" ax.set_title(\"GMM visualization\")\n",
" ax.scatter(history[i,:,0], history[i,:,1], color='b', s=50)\n",
" ax.legend(\"\")\n",
" fig.savefig(\"plot_{}.png\".format(i))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"from PIL import Image\n",
"pngs = [Image.open(\"plot_{}.png\".format(i)) for i in range(20)]\n",
"\n",
"gif = Image.new(\"RGBA\", (pngs[0].width, pngs[0].height), (255,255,255))\n",
"\n",
"gif.save(fp=\"comb.gif\", format='GIF', append_images=pngs,\n",
" save_all=True, duration=1000, loop=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![test](https://s5.gifyu.com/images/comb.gif)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment