@andlima
Last active April 19, 2019 02:56
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%pylab inline\n",
"\n",
"import seaborn as sns\n",
"import pandas as pd\n",
"import scipy.stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Expectation–maximization algorithm\n",
"\n",
"## Definition\n",
"\n",
"In statistics, an expectation–maximization (EM) algorithm is an iterative method to find maximum likelihood or maximum a posteriori (MAP) estimates of parameters in statistical models, where the model depends on unobserved latent variables. The EM iteration alternates between performing an expectation (E) step, which creates a function for the expectation of the log-likelihood evaluated using the current estimate for the parameters, and a maximization (M) step, which computes parameters maximizing the expected log-likelihood found on the E step. These parameter-estimates are then used to determine the distribution of the latent variables in the next E step.\n",
"\n",
"Source: https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm\n",
"\n",
"## This notebook\n",
"\n",
"This notebook implements a version of EM for a dataset with **a single categorical latent variable**, which can be used for clustering. Note that EM also applies to models with more than one latent variable and to models with numerical latent variables."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Randomized dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def randomized_dataset(n_observations, n_features, n_groups, seed=0):\n",
"    '''Generate a random dataset.\n",
"\n",
"    - Each observation belongs to a group (a cluster)\n",
"    - The group is identified by a categorical variable\n",
"    - Each group is a multivariate Gaussian distribution with dimension ``n_features``\n",
"      and independent features (diagonal covariance)\n",
"    - Each observation has features drawn from its group's distribution\n",
"\n",
"    Args:\n",
"        n_observations: number of observations in the dataset\n",
"        n_features: number of features\n",
"        n_groups: number of groups (values of the latent categorical variable)\n",
"        seed: seed for the random number generator\n",
"\n",
"    Returns:\n",
"        X (array[n_observations, n_features]): random dataset\n",
"        labels (array[n_observations]): identifiers for the groups\n",
"        probs (array[n_groups]): probability a random observation belongs to a certain group\n",
"        mu (array[n_groups, n_features]): for each group, the average value of each feature\n",
"        sd (array[n_groups, n_features]): for each group, the standard deviation of each feature\n",
"    '''\n",
"\n",
"    random.seed(seed)\n",
"\n",
"    # Random positive weights, normalized into group probabilities\n",
"    probs = exp(randn(n_groups))\n",
"    probs = probs / sum(probs)\n",
"    assert np.allclose(sum(probs), 1.0)\n",
"\n",
"    mu = random.randint(128, 256, size=(n_groups, n_features))\n",
"    sd = random.randint(2, 8, size=(n_groups, n_features))\n",
"\n",
"    labels = random.choice(range(n_groups), size=n_observations, p=probs)\n",
"\n",
"    X = mu[labels] + sd[labels] * random.randn(n_observations, n_features)\n",
"\n",
"    return (X, labels, probs, mu, sd)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create a dataset with 2 features, so we can view the clusters\n",
"\n",
"X, labels, probs, mu, sd = randomized_dataset(n_observations=1024, n_features=2, n_groups=4, seed=37)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot observations"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def plot_observations(X, labels, title):\n",
"    '''Plot observations coloring according to the given labels.'''\n",
"\n",
"    df = pd.DataFrame(X, columns=[f'x{k+1}' for k in range(X.shape[1])])\n",
"    df['label'] = labels.astype(int)\n",
"    for label in set(labels):\n",
"        x1, x2 = df.query('label == @label')[['x1', 'x2']].values.T\n",
"        scatter(x1, x2, label=label)\n",
"    plt.title(title)\n",
"    legend()\n",
"    show()\n",
"\n",
"plot_observations(X, labels, title='Original clusters')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### E and M steps"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def e_step(X, probs, mu, sd):\n",
"    '''Perform the Expectation step (E-step) for the EM-algorithm.\n",
"\n",
"    Args:\n",
"        X (array[n_observations, n_features]): the feature matrix\n",
"        probs (array[n_groups]): current estimate of the group probabilities\n",
"        mu (array[n_groups, n_features]): current estimate of the per-group feature means\n",
"        sd (array[n_groups, n_features]): current estimate of the per-group feature\n",
"            variances (despite the name), used as the diagonal of the covariance matrix\n",
"\n",
"    Returns:\n",
"        T (array[n_groups, n_observations]): the probability an observation belongs to a group\n",
"    '''\n",
"\n",
"    # The probability density function for the multivariate gaussian\n",
"    # (a 1-D ``cov`` is interpreted as a diagonal covariance matrix)\n",
"    pdf = scipy.stats.multivariate_normal.pdf\n",
"\n",
"    T = array([probs_ * pdf(X, mean=mu_, cov=sd_)\n",
"               for (probs_, mu_, sd_) in zip(probs, mu, sd)])\n",
"    # Normalize so that, for each observation, the probabilities sum to 1\n",
"    T /= sum(T, axis=0)\n",
"\n",
"    assert np.allclose(T.sum(axis=0), 1.0)\n",
"\n",
"    return T\n",
"\n",
"\n",
"def m_step(T, X, probs, mu, sd):\n",
"    '''Perform the Maximization step (M-step) for the EM-algorithm.\n",
"\n",
"    Update ``probs``, ``mu`` and ``sd`` inplace based on ``T`` and ``X``.\n",
"\n",
"    Args:\n",
"        T (array[n_groups, n_observations]): the probability an observation belongs to a group\n",
"        X (array[n_observations, n_features]): the feature matrix\n",
"        probs (array[n_groups]): to be updated to the next estimate of the group probabilities\n",
"        mu (array[n_groups, n_features]): to be updated to the next estimate of the means\n",
"        sd (array[n_groups, n_features]): to be updated to the next estimate of the variances\n",
"    '''\n",
"\n",
"    for j in range(len(probs)):\n",
"        probs[j] = mean(T[j], axis=0)\n",
"        mu[j] = sum(T[j].reshape(-1, 1) * X, axis=0) / sum(T[j])\n",
"        # Weighted variance around the updated mean\n",
"        sd[j] = sum(T[j].reshape(-1, 1) * (X - mu[j]) * (X - mu[j]), axis=0) / sum(T[j])\n",
"\n",
"    assert np.allclose(sum(probs), 1.0)\n",
"    return (probs, mu, sd)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialization, Iteration and Post-processing"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the variables for parameter estimation\n",
"# probs_hat: 1/n_groups for each group\n",
"probs_hat = ones(probs.shape) / len(probs)\n",
"mu_hat = random.random(mu.shape)\n",
"# sd_hat holds variances during the iterations (see m_step)\n",
"sd_hat = ones(sd.shape)\n",
"assert np.allclose(sum(probs_hat), 1.0)\n",
"\n",
"# Scale X to avoid numerical problems\n",
"scale_factor = np.max(X, axis=0)\n",
"XX = X / scale_factor\n",
"assert np.allclose(XX.max(), 1.0)\n",
"\n",
"# Iterate alternating between E-step and M-step\n",
"for iteration in range(1000):\n",
"    T = e_step(XX, probs_hat, mu_hat, sd_hat)\n",
"    m_step(T, XX, probs_hat, mu_hat, sd_hat)\n",
"\n",
"# Post-processing: scale mu_hat back to the original units and convert\n",
"# the variances in sd_hat to standard deviations in the original units\n",
"mu_hat = mu_hat * scale_factor\n",
"sd_hat = (sd_hat * (scale_factor ** 2)) ** 0.5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Results"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Original parameter values (ground truth)\n"
]
},
{
"data": {
"text/plain": [
"array([0.20602828, 0.42700046, 0.30769962, 0.05927164])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[163, 195],\n",
" [170, 213],\n",
" [191, 216],\n",
" [245, 215]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[2, 2],\n",
" [5, 4],\n",
" [6, 2],\n",
" [6, 4]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"* Estimated parameter values\n"
]
},
{
"data": {
"text/plain": [
"array([0.05371094, 0.32170901, 0.20373055, 0.4208495 ])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[245.59807655, 215.79282986],\n",
" [191.2382189 , 216.08071049],\n",
" [163.09387915, 195.16652376],\n",
" [170.57574173, 212.59164509]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[6.81597025, 3.41986827],\n",
" [6.09132375, 1.94990394],\n",
" [1.99913102, 1.9359254 ],\n",
" [4.93429354, 4.15610579]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Note: the order of the groups can be different\n"
]
}
],
"source": [
"print('* Original parameter values (ground truth)')\n",
"display(probs)\n",
"display(mu)\n",
"display(sd)\n",
"\n",
"print()\n",
"print()\n",
"print('* Estimated parameter values')\n",
"display(probs_hat)\n",
"display(mu_hat)\n",
"display(sd_hat)\n",
"\n",
"print()\n",
"print()\n",
"print('Note: the order of the groups can be different')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# The predicted label is the one with the largest probability\n",
"labels_hat = array(T).argmax(axis=0)\n",
"\n",
"df = pd.DataFrame(list(zip(labels, labels_hat)), columns=['label', 'label_hat'])\n",
"count_df = df.assign(Q=1).groupby(['label', 'label_hat'])['Q'].count()\n",
"unstacked = count_df.unstack().fillna(0.)\n",
"unstacked = unstacked / unstacked.sum()\n",
"sns.heatmap(unstacked, annot=True, fmt='.1%');"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_observations(X, labels, title='Original clusters')\n",
"plot_observations(X, labels_hat, title='Predicted clusters')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
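
The notebook scales `X` before iterating "to avoid numerical problems": with unscaled features around 128 to 256 and initial means in [0, 1), the Gaussian densities in `e_step` would likely underflow to zero and the normalization would divide by zero. As a possible alternative, the E-step can be computed in log space. The sketch below assumes the same model as the notebook (independent features, per-feature variances); the function name `e_step_log` and the parameter name `var` are hypothetical and not part of the notebook. It also returns the log-likelihood, which could be used to monitor convergence instead of running a fixed number of iterations.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def e_step_log(X, probs, mu, var):
    """E-step for a mixture of diagonal Gaussians, computed in log space.

    Hypothetical helper (not part of the notebook).

    Args:
        X: array[n_observations, n_features]
        probs: array[n_groups], group probabilities
        mu: array[n_groups, n_features], per-group feature means
        var: array[n_groups, n_features], per-group feature variances

    Returns:
        T: array[n_groups, n_observations], posterior group probabilities
        log_likelihood: float, log-likelihood of X under the current parameters
    """
    X, probs, mu, var = (np.asarray(a, dtype=float) for a in (X, probs, mu, var))

    # Per-feature log densities, broadcast to (n_groups, n_observations, n_features)
    log_pdf = norm.logpdf(
        X[None, :, :], loc=mu[:, None, :], scale=np.sqrt(var)[:, None, :]
    )

    # Independent features: the joint log density is the sum over features
    log_joint = np.log(probs)[:, None] + log_pdf.sum(axis=2)

    # Log of the mixture density of each observation, computed stably
    log_mix = logsumexp(log_joint, axis=0)

    # Normalizing in log space avoids the 0 / 0 that underflow would cause
    T = np.exp(log_joint - log_mix)
    return T, log_mix.sum()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Observations far from the initial means, as with the unscaled dataset
    X_demo = 200.0 + rng.normal(size=(5, 2))
    T_demo, ll_demo = e_step_log(
        X_demo,
        probs=[0.5, 0.5],
        mu=[[0.0, 0.0], [1.0, 1.0]],
        var=[[1.0, 1.0], [1.0, 1.0]],
    )
    print(T_demo.sum(axis=0), ll_demo)
```

Under these assumptions the result should agree with the notebook's `e_step` whenever the latter does not underflow, since a 1-D `cov` passed to `scipy.stats.multivariate_normal.pdf` is treated as a diagonal covariance.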