Last active
November 18, 2019 08:00
-
-
Save piyush01123/81d7e05b5b32e41fb2e83c2d89e68722 to your computer and use it in GitHub Desktop.
Gaussian Mixture Model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"import scipy.stats as stats\n", | |
"from PIL import Image\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Simulating Points \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.collections.PathCollection at 0x1188a4f50>" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD8CAYAAABq6S8VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAX9ElEQVR4nO3dX4wkV3XH8d/BXkT4EzvsrmJie2OjxJEWg0V2zHqDSACjzEBsNlrlARAIwsMmKBBIiCzAUkbz4iBAJEiQMKPgiAgnhJglyIh4bSOEFM2w0GsMsbFBFjHGDsiDIwsEEY7FyUN1a3vb/aeq7q2qe6u/H2k0Mz3Tt251VZ06fe6tanN3AQDy9ZSuOwAACEMgB4DMEcgBIHMEcgDIHIEcADJHIAeAzEUJ5Gb2Z2Z2j5ndbWb/bGZPi9EuAGCx4EBuZhdK+lNJK+5+uaRzJL0mtF0AQDmxSivnSvoFMztX0tMl/XekdgEAC5wb2oC7P2xmH5D0oKT/lXSbu9827zn79u3zSy65JHTRALBUTp8+/UN33z/5eHAgN7NfknRU0qWSHpP0r2b2enf/xMT/HZd0XJIOHDigwWAQumgAWCpm9t1pj8corbxC0n+5+667/5+kE5J+a/Kf3H3L3VfcfWX//iedUAAANcUI5A9KusrMnm5mJulqSfdGaBcAUEJwIHf3U5JulnSnpP8ctrkV2i4AoJzgGrkkufu6pPUYbQEAquHKTgDIHIEcADJHIAeAzBHIAXRqZ0daWyu+o54og50AUNfGhnTyZPHzrbd225dcEcgBdGp9/ezvqI5ADqBTR46QiYeiRg4AmSOQA0DmCOQAkDkCOQBkjkAOAJkjkANA5gjkAJA5AjkAZI5ADgCZI5ADQOYI5ACQOQI5AAzlektdbpoFAEO53lI3SkZuZueb2c1mdp+Z3WtmR2K0CwAhqmbY6+vS6mp+t9SNlZF/SNKt7v4HZvZUSU+P1C4A1FY1w871lrrBGbmZnSfptyV9TJLc/XF3fyy0XQCYp0y2nWuGXVWMjPxSSbuS/sHMrpB0WtLb3f0nEdoGgKnKZNu5ZthVxaiRnyvpNyX9nbu/UNJPJL1r8p/M7LiZDcxssLu7G2GxAJbFtOx7WbLtMmIE8ockPeTup4a/36wisJ/F3bfcfcXdV/bv3x9hsQBytrUl7dtXfJfml0pG2ffGxpnHRtn2EaZWhJdW3P0HZvY9M/sNd/+WpKslfTO8awD67D3vkR59tPh+/Pj8Ugkf0DxfrFkrb5N003DGynck/WGkdgH01A03FEH8hhuK3+cF62Wpdddl7t76QldWVnwwGLS+XABo0s5O8c5ifb2Zko+ZnXb3lcnHubITACLp6spQ7rUCLLlc7y/SlJDXo6uZNGTkwJLL9f4iTQl5Pbqq5fcmIyerAOppM4vM4TjNcX56bwY719aKs+jqKlkFkCqO0zC9H+xknimQPo7TZvSmtLJMV3nl8PYUmKat4zSFY6TNPvQmkC+TaZcrpy6FAwvx7exIV11VfKW0bVM4RtrsQ29KK11oevL/LDm+PWVmRD9tbEinTp35OZVtm8Ix0mof3L31r0OHDnkfrK66S8V3zLe9XbxO29td9yQNOb8e433f3nY/fLj4ynFdciNp4FNiKoE8QFMHY84HOcrJOQlY1Pe+7r8prBeBPCOHDxdb5vDhrnsyXwo7dq5ivnZtb4dFyytzkspx30nh5Esgr6DrnSyXQJ7Cjo3w7RB7fy/TXo77TtdxwZ1AXknXO1kKO0wZufSz70K3Qxf7O/tOPQTyClLbyXJ+G470pbZPpNSfun1pah0I5Jna3nbfuzdextT1uw1gkbr76GTwrBpMp/1/3b40dZzNCuTMI0/cxkbxcVh798aZj5rC/FrkYfw6Cam9aybq7qOjaxUee0w6//zi+2iOe5n57dOudTh2TBoMiu9VtH6cTYvuTX81mZGn9LbMPbw/05
5PqQVtGM8qRz/v3dv9vjJrnx09Pj5ZoMoc98n58eNtpfIOVstSWkmtdNBEf2K2mdrr1bRlP3FVWf/x/93cdN+zp9q+0tRrPb7PLkp0Qksjhw+ntb8sTSBfdLZue4M0sVwy8vr6dOKqs+1CA1uVjLzKskbrsrm5eJ2qBOrUBitDLU0gn6VPBzDqS/UAraPOPh07sM1rr8qyxk8UKWT9qWo8kEs6R9LXJH1u0f92EciXbYOj/1LYp2MlSFUy8mXWRiD/c0n/lGog70JOJZA67acQSNCtFPeB1MuZIRoN5JIukvQFSS8nkJ+R06BknfZTuzQcT5ZiAhDa7qJlpjbBIOZr1HQgv1nSIUkvnRXIJR2XNJA0OHDgQPgaZSCljHzR87vIyBm3iCvmBS1lTbYfa7rsrH6XuUAuVt1+3vPq1P9jbIPGArmkayT97fDnmYF8/CvFjLxPs1qmSTFoTpu3S3Ze37RtHOt1LTvYOW12S51+zfp7ndkz054fehzUmZGTdEYu6a8kPSTpAUk/kPRTSZ+Y95wUA3lXgS72YFHMjLtNKZ5oclNmG9fdD8pun2kZ8+Zm8djmZvX2Jvvc9DvTttupqpXph2TkzS93UaaSayBM/UTTF3X3kyrbZ1aWXvedQu77dkwE8qFUAkbZfpQ5KKq016QU+oD5uthGdevLTEl8sqW/IGgklbN72X5MfshEysEyldcWzQsZ1Cy7n8z7v1iDqrlZikDeZI0wtrL9SOHTguq+e0C3mtwe04JslTp66P4UsvycLUUg72JDNh286kylitGn8TaW4QDpoya3W5sZcdlldXUstpnALEUg7yIj7DLIzVp2jD6Nt0Gmnac2t1vb2X8b61Z2fKrNGLAUgbxpTWQBdZ6/aBAoRomJ4I0q2s7+2yg5Tq4TGXlPAnnsGmCVNkOfU6cNgnk6Ut8WdUqAIdoI5DH7HastAnkEZTPd0QURZa4+q5M9x66Bz0JtPB25bItp+1UTfU/9xDZp9BqEfsISgbwlow1W9dNUyrTZ9kGc28HSZ13UhOvoqp6duvEEL+Q4JpC3YHv7zGcExryIoeqBwIGDOmIkDDlNAW5bjPUmkLegyoHQZP1tWj82N93PO8/94EEOsq6k/tq21b9cykQpIpC3oKlBzkXtlhldH72tW7RMDrLm5P7aVp3plMIsjzpS7h+BPIKuRrHHR+jrToHc3HR/xjOKr+uum32AjUpDKe7EuUs5QJQx60Q0Wq/Rfjr6e64nrpT7TSCPYNEGbqqWPR7Iy85tndf/WYMuKe/A6N6sfW2034z2z5DZVSmc7FLowyy9CORdv8BVSxyLlB3hH39sMmuuU6IJuZAImBRzvyGZmK8XgTyVjVy19lfl8TJZ/3hGneJFC0BdbV0jkateBPJUNlBI5l11wGhWW6EXFizqJxBbk8fveNt93o97Echjq7tjhWTes+qJTfW5yv+ncqJEPzUZYCeTpa4H7Zs6lgjkU8Tescq0N22Ev6mNPlmGAbrUVkbu3n1W3tTyCeRTxN6x6ma/TW30yTJMGxk3WX36lmEbdb2OZOQdaDNTaPJ/Fj2njSyl60wIi7GN8kUgn6Ot2l3X7ZCRw51t1KbYr/WsQH6uApnZxZL+UdIvS3JJW+7+odB227S+fvb3FNuO0c6RI9Ktt4b1I4VlIAzbqD0bG9LJk8XPTb7mVgT5gAbMniPpOe5+p5k9S9JpSb/v7t+c9ZyVlRUfDAZBy+2DnZ1iQ6+vFwcXgH6JfYyb2Wl3X5l8PDgjd/fvS/r+8Ocfm9m9ki6UNDOQo9jA114rPfpo8TsZEtA/bb37eUrMxszsEkkvlHRqyt+Om9nAzAa7u7sxFxvdzo60tlZ8b8rGRhHE9+5tpqQDYHlEC+Rm9kxJn5b0Dnf/0eTf3X3L3VfcfWX//v2xFtuIUV1rY6O5ZayvS6ur0i23NFtWaeOkBKBbwaUVSTKzPSqC+E3ufiJGm20br2U1Ofg50tZbrrYGWwB0J8asFZP0MU
n3uvsHw7vUjcmA15eg18ZJCUC3YpRWXizpDZJebmZ3Db9eFaHdVo1KHX0LeKPMn1kxWBbLWE6MMWvlPyRZhL50Knapg6mFQDeWsZwYddbKMll01m9jwDS2Zcxk0D99fXc9T5TBzmW06KyfY216GTMZ9M8yXrlKIK9pUaDucmeqW9bJ8eQDIMIl+nVwiX6z1taKzHp1dfkyE6DPZl2iT408USH16mWsEQLLjEDegBiDhiGDpUw5BJYLgbwBMWaspJ5VM8MFSAeDnQ3I5d7hIZjhAqSDQN6A1INwDMxwAdJBIEcty3CyAnJBjTwzTdamqXsDeSIjz0yTtWnq3kCeCOSZyeGDogG0iys7ASATXNkJAD1FIAeAzBHIASBzBHIAyByBHAAyRyAHgMxFCeRmtmZm3zKz+83sXTHaBACUExzIzewcSR+R9EpJByW91swOhrYLACgnRkb+Ikn3u/t33P1xSZ+UdDRCuwCAEmIE8gslfW/s94eGj53FzI6b2cDMBru7uxEWCwCQWhzsdPctd19x95X9+/e3tVgswB0PgfzFuGnWw5IuHvv9ouFjyAB3PATyFyOQf1XSr5vZpSoC+GskvS5Cu2gBdzwE8hccyN39CTN7q6STks6RdKO73xPcM7SCT/oB8hflfuTu/nlJn4/RFgCgGq7sBIDMEcgBIHMEcgDIHIEcADJHIAeAzBHIASBzBHIAyByBHAAyRyAHgMwRyAEgcwRyAMgcgRxoCPd6R1ui3DQLwJNxr3e0hUAONIR7vaMtBHKgIdzrHW2hRg4AmSOQA0DmCOQAkDkCOQBkjkAOAJkLCuRm9n4zu8/MvmFmnzGz82N1DABQTmhGfruky939BZK+Lend4V0CEANXli6PoEDu7re5+xPDX78s6aLwLgH9tSi4xgy+oytLNzaaaR/piHlB0Jsl/UvE9oDeWXTZfszL+qddWcptA/ppYSA3szskXTDlT9e7+2eH/3O9pCck3TSnneOSjkvSgQMHanUWyN2xY9JgUHyfJuZl/dOuLOW2Af1k7h7WgNmbJP2RpKvd/adlnrOysuKDwSBouUDXdnaKDHd9vQiaZaytFRnx6ioZMaozs9PuvjL5eOislTVJ10l6ddkgDvTFtBr0IuvrRRCfzIipXSNE6KyVD0t6lqTbzewuM/tohD4BWZgVlOcZlTsmM/iqJ4WmA/+09jnZpCtosNPdfy1WR4DcxLy7YdXadd1By7LloGntM1CaLm5jC7RoViCtelKoO2hZNhhPa5+B0nQFD3bWwWAnllUXg53jJw+p+gAt0tHIYCfQhj7VZuvU1UON199n1ejL6tO26BNKK0hen2qzXXxqUMySSJ+2RZ8QyJG8ZajN1pmTXlaXg7JoBzVyIAGTtfMmAzvyRY0cSMx4vXmydj45r5zaNOYhkAMljAfSWEF13iDkrMB+7bXS1lb7FwMhce7e+tehQ4ccyMnqqrtUfB//eXv7zPeqqjx3e9t9795iuaPvq6vVl1nG+PohLZIGPiWmMtgJlDDrApmQWRxVBiGPHJFuuaVY3rFj0okTzQ04MqCZHwY7gQBND0qGtM+Aaf8w2Ak0IPQCm0Xq3GGx7nPbrI1Th4+L0gqQsDpljlEmPvrwiqZvxFUHFxbFRUYOtGBrS9q3r/gulc9I62T8oyB54sT0585adpu3D+jiVgW9Nm0EtOkvZq1g2YzPOHFvdmbIotkwTc9K2d52P3y4+KozmydlIbOUYtCMWSsEcmBMjAN1Whubm0UQ39yMt5y6/Smz7JD+jU4UfZzC2PXUTAI5UEKMA7Xrg308CG9uuu/ZU70/IetARt4cAjlQQlMZedW2Y2TEq6tnSjp79ixua3yZXQcsTEcgBxpWtzY97XmhGfF4Rj5e0pmn63cSWIxAjqXRVTa5KBDO6te05zWxDovaJAtPX6OBXNI7JbmkfWX+n0COJnWVWdYNhFWfV3c5ZNz5ayyQS7pY0klJ3yWQIwWxMsumZ3fUVTcgk3Hnr8
lAfrOkKyQ9QCBHn5QJmLGz3Jgnjz4E7hgDx33SSCCXdFTSh4Y/E8jRK11k5DFPDDmWUiZfz1nrkOO6xVA7kEu6Q9LdU76OSjol6TwvEcglHZc0kDQ4cOBAu2sPZKLOiaGrrLWJ9icDNBn52aJn5JKeL+mRYQB/QNITkh6UdMGi55KRIyepBI0qs17aWHaV5S5TOahJjU8/pLSCPog9p7uNfrQR/EKnSC5rKSS2WYGc29gCY6bdXrWLT8yp0o8qnzRU12iZx44Vd04cfVhF2eXyqUMNmxbdm/4iI0eqmrxEv+02mkBm3S2RkQOLxchuY3xoQhtZdh1k1mnigyWAyKZ9aEJfPtqs6Y+2Qz0EciCyacEu5LM3R7o4GZRdZl9OVLkikAM1VA1cMT7aLMbJYNKi9Si7zMn/I7C3bFrhvOkvBjuRuy4G/arcknae8YHUundsXPR/sW7Di7OJ29gC8eR0s6x57TS1HrE+GCN227kjkAOZa/Oujk2b14e6H9CxDAjkQEUpBLx5Uu/fPJRe6pkVyBnsBGZoYnAxpjr9izUIGdpOyOAvUyCnmBbdm/4iI0cOUsz8Qj8guezdBau2g3aI0gpQXWrBfF4ALVN33tyMM7sktddlWcwK5FyiD8wR43L7mOZdIj+vr7P+VveS+1RvIbCsqJEDc8S4kGeaeTXmeX+bVx+e19em1gNpsCJbb9fKyooPBoPWlwukYm2tyJBXV5+c2c77W5v9QHrM7LS7r0w+TmkF6MC8kkabdxjkbob9QEYOAJmYlZFTI0evLNPNmnJc1xz7nANKK+iFnZ1iZsZjj0mnThWP9b3mm9qMmjJy7HMOCOTohVGAOHx4eWZn5FjfzrHPOSCQoxfGA8SyXLqd41zuHPucAwI5eoEAgWUWPNhpZm8zs/vM7B4ze1+MTgEAygsK5Gb2MklHJV3h7s+T9IEovQLQO8xYaU5oaeUtkt7r7j+TJHd/JLxLAPqIGSvNCS2tXCbpJWZ2ysy+ZGZXzvpHMztuZgMzG+zu7gYuFkBuuN9LcxZm5GZ2h6QLpvzp+uHzny3pKklXSvqUmT3Xp1wu6u5bkrak4srOkE4DyA8D0s1ZmJG7+yvc/fIpX5+V9JCkE8Nb5X5F0s8l7Wu60wCoOeOM0Br5v0l6maQvmtllkp4q6YfBvQKwEDVnjIQG8hsl3Whmd0t6XNIbp5VVAMTHVZIYCQrk7v64pNdH6guACqg5Y4S7HwJA5gjkAJA5AjkAZI5ADgCZI5ADQOYI5ACQOQI5AGTOurh+x8x2JX239QXHt0/9vJKV9cpLH9erj+skha/Xr7r7/skHOwnkfWFmA3df6bofsbFeeenjevVxnaTm1ovSCgBkjkAOAJkjkIfZ6roDDWG98tLH9erjOkkNrRc1cgDIHBk5AGSOQB6Bmb3NzO4zs3vM7H1d9ycmM3unmbmZZf/JT2b2/uF2+oaZfcbMzu+6TyHMbM3MvmVm95vZu7ruTwxmdrGZfdHMvjk8nt7edZ9iMrNzzOxrZva5mO0SyAOZ2cskHZV0hbs/T9IHOu5SNGZ2saTflfRg132J5HZJl7v7CyR9W9K7O+5PbWZ2jqSPSHqlpIOSXmtmB7vtVRRPSHqnux9U8VnAf9KT9Rp5u6R7YzdKIA/3FknvdfefSZK7P9Jxf2L6a0nXSerFQIq73+buTwx//bKki7rsT6AXSbrf3b8z/ICXT6pIKLLm7t939zuHP/9YRdC7sNtexWFmF0n6PUl/H7ttAnm4yyS9xMxOmdmXzOzKrjsUg5kdlfSwu3+967405M2S/r3rTgS4UNL3xn5/SD0JeCNmdomkF0o61W1PovkbFYnRz2M3HPqZnUvBzO6QdMGUP12v4jV8toq3gVdK+pSZPTeHzy5dsF7vUVFWycq8dXL3zw7/53oVb+FvarNvKM/Mninp05Le4e4/6ro/oczsGkmPuPtpM3tp7PYJ5CW4+ytm/c3M3iLpxDBwf8XMfq7ifgq7bfWvrl
nrZWbPl3SppK+bmVSUIO40sxe5+w9a7GJl87aVJJnZmyRdI+nqHE62czws6eKx3y8aPpY9M9ujIojf5O4nuu5PJC+W9Goze5Wkp0n6RTP7hLtH+cxj5pEHMrM/lvQr7v6XZnaZpC9IOpB5kDiLmT0gacXds76JkZmtSfqgpN9x9+RPtPOY2bkqBmyvVhHAvyrpde5+T6cdC2RF5vBxSf/j7u/ouj9NGGbkf+Hu18Rqkxp5uBslPdfM7lYx4PTGPgXxnvmwpGdJut3M7jKzj3bdobqGg7ZvlXRSxYDgp3IP4kMvlvQGSS8fbqO7hlks5iAjB4DMkZEDQOYI5ACQOQI5AGSOQA4AmSOQA0DmCOQAkDkCOQBkjkAOAJn7f3I1cWhLLP7HAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"SEED = 0\n", | |
"np.random.seed(SEED)  # fix the seed so the simulated data, and hence the EM run below, are reproducible\n", | |
"\n", | |
"n_points = 100  # points sampled from each gaussian\n", | |
"\n", | |
"actual_means = [[2, 4], [-5, 3], [0, -4]]\n", | |
"actual_cov_mats = [[[1, 0], [0, 1]], [[1, 0], [0, 1]], [[1, 0], [0, 1]]]\n", | |
"\n", | |
"blobs = [np.random.multivariate_normal(m, s, (n_points,)) for m, s in zip(actual_means, actual_cov_mats)]\n", | |
"X = np.concatenate(blobs)  # shape (len(actual_means) * n_points, 2)\n", | |
"\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.scatter(X[:, 0], X[:, 1], color='b', s=2)\n", | |
"ax.set(title=\"Simulated points\", xlabel=\"$x_1$\", ylabel=\"$x_2$\");" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Gaussian Mixture Model \n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"In a Gaussian Mixture Model (GMM), we assume that our data are drawn from a mixture of K Gaussians, so that\n", | |
"\n", | |
"\\begin{equation}\n", | |
"P(X) = \\sum_{k=1}^K \\pi_k P(X | \\mu_k, \\sigma_k)\n", | |
"\\end{equation}\n", | |
"\n", | |
"where the mixing weights satisfy $\\pi_k \\ge 0$ and $\\sum_{k=1}^K \\pi_k = 1$, and $P(X | \\mu, \\sigma)$ is the normal probability density given by\n", | |
"\\begin{equation}\n", | |
"P(X | \\mu, \\sigma) = \\frac {1}{\\sqrt{2\\pi} \\sigma} e^{\\frac{-{(X-\\mu)}^2}{2\\sigma^2}}\n", | |
"\\end{equation}\n", | |
"\n", | |
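"which in $d > 1$ dimensions generalizes to the multivariate normal density with mean vector $\\mu$ and covariance matrix $\\Sigma$ (where $|\\Sigma|$ denotes the determinant of $\\Sigma$)\n", | |
"\\begin{equation}\n", | |
"P(X | \\mu, \\Sigma) = \\frac{1}{(2\\pi)^{d/2} {|\\Sigma|}^{\\frac 1 2}} e^{\\frac{-{(X-\\mu)}^T \\Sigma^{-1} (X-\\mu)}{2}}\n", | |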
"\\end{equation}\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Our EM algorithm \n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Assume K, the number of Gaussians, is some positive integer (K = 3 in our case).\n", | |
"\n", | |
"Step 1: Initialize means $\\mu_k$, covariance matrices $\\Sigma_k$ and weights $\\pi_k$ for k=1 to K.\n", | |
"\n", | |
"### E step:\n", | |
"\n", | |
"Step 2: For each point $X_i$ and each Gaussian $k$, calculate the weighted density $\\pi_k P(X_i|\\mu_k, \\Sigma_k)$, where $P(X_i|\\mu_k, \\Sigma_k)$ is our familiar multivariate normal p.d.f. in $d$ dimensions ($d = 2$ here, and $|\\Sigma_k|$ is the determinant of $\\Sigma_k$)\n", | |
"\\begin{equation}\n", | |
"P(X_i | \\mu_k, \\Sigma_k) = \\frac{1}{(2\\pi)^{d/2} {|\\Sigma_k|}^{\\frac 1 2}} e^{\\frac{-{(X_i-\\mu_k)}^T \\Sigma_k^{-1} (X_i-\\mu_k)}{2}}\n", | |
"\\end{equation}\n", | |
"\n", | |
"This gives a matrix of size $N \\times K$, where $N$ is the number of points and $K$ is the assumed number of Gaussians.\n", | |
"\n", | |
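"Step 3: Normalize this matrix so that each row sums to 1, i.e. $P_{ik} = \\frac{\\pi_k P(X_i|\\mu_k, \\Sigma_k)}{\\sum_{j=1}^K \\pi_j P(X_i|\\mu_j, \\Sigma_j)}$. Each element $P_{ik}$ (the *responsibility*) represents the probability of the $i^{th}$ point being associated with the $k^{th}$ Gaussian.\n", | |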
"\n", | |
"\n", | |
"### M step:\n", | |
"\n", | |
"Step 4: Re-calculate the parameters of the model, i.e. $\\mu_k$, $\\Sigma_k$ and $\\pi_k$ for k=1 to K. The means and covariance matrices are computed as means and covariance matrices of the points weighted by the $k^{th}$ column of the $P$ matrix from the E step, and each $\\pi_k$ is simply the sum of that column divided by $N$.\n", | |
"\n", | |
"\\begin{equation}\n", | |
"\\mu_k = \\frac{\\sum_{i=1}^N P_{ik}X_i}{\\sum_{i=1}^N P_{ik}}\n", | |
"\\end{equation}\n", | |
"\n", | |
"\\begin{equation}\n", | |
"\\Sigma_k = \\frac {\\sum_{i=1}^N P_{ik}(X_i-\\mu_k){(X_i-\\mu_k)}^T}{\\sum_{i=1}^N P_{ik}}\n", | |
"\\end{equation}\n", | |
"\n", | |
"\\begin{equation}\n", | |
"\\pi_k = \\frac {1}{N} {\\sum_{i=1}^N P_{ik}}\n", | |
"\\end{equation}\n", | |
"\n", | |
"\n", | |
"##### Repeat steps 2 to 4 until a convergence criterion is met (e.g. the log-likelihood stops increasing) or for a fixed number of iterations\n", | |
"In our case we simply repeat for 100 iterations." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Initial guesses, deliberately placed away from the true means\n", | |
"mu_s = [[-5, -5], [0, -5], [0, 5]]\n", | |
"n_components = len(mu_s)  # K\n", | |
"n_samples = len(X)        # N\n", | |
"pi_s = [1 / n_components] * n_components\n", | |
"sigma_s = [np.eye(X.shape[1]) for _ in range(n_components)]\n", | |
"\n", | |
"# history[t] holds the (K, d) array of means after t EM iterations\n", | |
"history = [np.array(mu_s, dtype=float)]\n", | |
"\n", | |
"n_iterations = 100\n", | |
"for _ in range(n_iterations):\n", | |
"    # E step: recompute the responsibilities probs[i, k], i.e. the probability that point i\n", | |
"    # belongs to gaussian k. Compare with the K-Means assignment step, except that here the\n", | |
"    # assignment is soft. Work in log space so that a point far from every gaussian does not\n", | |
"    # underflow to 0 for all components and produce 0/0 = NaN when normalizing.\n", | |
"    log_weighted = np.column_stack([\n", | |
"        np.log(pi) + stats.multivariate_normal(mu, sigma).logpdf(X)\n", | |
"        for pi, mu, sigma in zip(pi_s, mu_s, sigma_s)\n", | |
"    ])  # shape (N, K)\n", | |
"    log_weighted -= log_weighted.max(axis=1, keepdims=True)\n", | |
"    probs = np.exp(log_weighted)\n", | |
"    probs /= probs.sum(axis=1, keepdims=True)  # each row now sums to 1\n", | |
"\n", | |
"    # M step: re-estimate mu_s, sigma_s and pi_s as responsibility-weighted statistics.\n", | |
"    # Compare with re-computing the centroids in K-Means.\n", | |
"    for k in range(n_components):\n", | |
"        weights = probs[:, k]\n", | |
"        total = weights.sum()\n", | |
"        mu_s[k] = weights @ X / total\n", | |
"        centered = X - mu_s[k]\n", | |
"        sigma_s[k] = (centered * weights[:, None]).T @ centered / total\n", | |
"        pi_s[k] = total / n_samples\n", | |
"\n", | |
"    history.append(np.array(mu_s))\n", | |
"\n", | |
"print(\"Mu\", mu_s)\n", | |
"print(\"SIGMA\", sigma_s)\n", | |
"print(\"PI\", pi_s)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Visualization of the Gaussian means across iterations" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Stack the per-iteration means into one (n_iterations + 1, K, d) array\n", | |
"mean_history = np.asarray(history)\n", | |
"\n", | |
"for i, means in enumerate(mean_history):\n", | |
"    fig, ax = plt.subplots(figsize=(10, 6))\n", | |
"    ax.set_xlim(-10, 10)\n", | |
"    ax.set_ylim(-10, 10)\n", | |
"    ax.set_title(\"GMM visualization (iteration {})\".format(i))\n", | |
"    ax.scatter(X[:, 0], X[:, 1], color='lightgray', s=2)  # the data, for context\n", | |
"    ax.scatter(means[:, 0], means[:, 1], color='b', s=50)\n", | |
"    fig.savefig(\"plot_{}.png\".format(i))\n", | |
"    # Close each figure once saved: otherwise 100+ figures stay open (matplotlib warns\n", | |
"    # after 20) and, with the inline backend, all of them are rendered below the cell.\n", | |
"    plt.close(fig)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Number of frames in the GIF: the first 20 iterations (use len(history) to include all of them)\n", | |
"n_frames = min(20, len(history))\n", | |
"\n", | |
"frames = []\n", | |
"for i in range(n_frames):\n", | |
"    # convert() loads the pixel data, so the PNG file can be closed right away\n", | |
"    with Image.open(\"plot_{}.png\".format(i)) as png:\n", | |
"        frames.append(png.convert(\"RGB\"))\n", | |
"\n", | |
"# Use the first plot as the base frame instead of a blank canvas, so the GIF\n", | |
"# does not open with an empty white frame\n", | |
"frames[0].save(\"comb.gif\", format='GIF', append_images=frames[1:],\n", | |
"               save_all=True, duration=1000, loop=0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment