Created
May 25, 2020 11:55
-
-
Save Janglee123/1f5f6cd3eee69f99dd4c8465445887e2 to your computer and use it in GitHub Desktop.
EM Algo.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "EM Algo.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyOm4E7DK4w7qh41uyZV8gE6", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Janglee123/1f5f6cd3eee69f99dd4c8465445887e2/em-algo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QnRLy86tlM7E", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt \n", | |
"import copy\n", | |
"from scipy.stats import multivariate_normal" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KWYWOMsMdY8b", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
def log_sum_exp(Z):
    """Numerically stable log(sum(exp(Z))).

    Shifts every entry by the maximum before exponentiating so that
    large log-probabilities do not overflow to inf.
    """
    shift = np.max(Z)
    return shift + np.log(np.exp(Z - shift).sum())
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "GYtMreqYdcLm", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
def loglikelihood(data, weights, means, covs):
    """Log-likelihood of `data` under a Gaussian mixture model.

    Parameters
    ----------
    data : sequence of length-D points (array-like).
    weights : length-K mixture weights (assumed to sum to 1).
    means : K mean vectors of length D.
    covs : K covariance matrices of shape (D, D).

    Returns
    -------
    float
        sum over points x of log( sum_k weights[k] * N(x | means[k], covs[k]) ).
    """
    # Stable log-sum-exp over the per-cluster log joint probabilities.
    from scipy.special import logsumexp

    num_clusters = len(means)
    num_dim = len(data[0])

    ll = 0
    for d in data:

        Z = np.zeros(num_clusters)
        for k in range(num_clusters):

            delta = np.array(d) - means[k]
            # pinv tolerates singular/degenerate covariances.
            exponent_term = np.dot(delta.T, np.dot(np.linalg.pinv(covs[k]), delta))

            # slogdet is stable where log(det(...)) would underflow to
            # -inf or go NaN when rounding makes a tiny determinant
            # negative (likely for high-dimensional TF-IDF covariances).
            _, logdet = np.linalg.slogdet(covs[k])

            Z[k] = np.log(weights[k])
            Z[k] -= 1/2. * (num_dim * np.log(2*np.pi) + logdet + exponent_term)

        ll += logsumexp(Z)
    return ll
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7lZrctwadf9F", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
def EM(data, init_means, init_covariances, init_weights, maxiter=1000, thresh=1e-4):
    """Fit a Gaussian mixture to `data` by expectation-maximization.

    Parameters
    ----------
    data : sequence of length-D points.
    init_means, init_covariances, init_weights : initial per-cluster mean
        vectors, (D, D) covariance matrices, and mixture weights.
    maxiter : maximum number of EM iterations.
    thresh : stop once the log-likelihood improvement falls below this.

    Returns
    -------
    resp : (num_data, num_clusters) array of responsibilities from the
        final E-step (row j = posterior cluster probabilities of point j).
        NOTE(review): the fitted means/covariances/weights are local and
        discarded; only the responsibilities are returned.
    """
    # Shallow list copies: elements are only ever reassigned below (never
    # mutated in place), so the caller's arrays are left untouched.
    means = init_means[:]
    covariances = init_covariances[:]
    weights = init_weights[:]
    
    num_data = len(data)
    num_dim = len(data[0])
    num_clusters = len(means)
    
    resp = np.zeros((num_data, num_clusters))
    ll = loglikelihood(data, weights, means, covariances)
    ll_trace = [ll]
    
    for i in range(maxiter):
        
        # E-step: compute responsibilities
        for j in range(num_data):
            for k in range(num_clusters):
                # allow_singular tolerates rank-deficient covariances
                resp[j, k] = weights[k]*multivariate_normal.pdf(data[j],means[k],covariances[k], allow_singular=True)
        row_sums = resp.sum(axis=1)[:, np.newaxis]
        resp = resp / row_sums # normalize over all possible cluster assignments

        # M-step
        counts = np.sum(resp, axis=0)  # effective number of points per cluster
        
        for k in range(num_clusters):
            
            weights[k] = counts[k]/num_data
            # New mean: responsibility-weighted average of the data points.
            weighted_sum = 0
            for j in range(num_data):
                weighted_sum += (resp[j,k]*data[j])
            means[k] = weighted_sum/counts[k]
            
            # New covariance: responsibility-weighted outer products of the
            # deviations from the (already updated) mean.
            weighted_sum = np.zeros((num_dim, num_dim))
            for j in range(num_data):
                weighted_sum += (resp[j,k]*np.outer(data[j]-means[k],data[j]-means[k]))
            covariances[k] = weighted_sum/counts[k]
        
        ll_latest = loglikelihood(data, weights, means, covariances)
        ll_trace.append(ll_latest)
        
        # Converged once the improvement drops below thresh (and the
        # likelihood is finite).
        if (ll_latest - ll) < thresh and ll_latest > -np.inf:
            break
        ll = ll_latest

    return resp
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BaPDeTN6kH9j", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
# Build TF-IDF embeddings for each sentence using scikit-learn.
from sklearn.feature_extraction.text import TfidfVectorizer

documents = ["Epidemic spread dengue",
             "malraia spread",
             "Epidemic vaccine",
             "Epidemic spread fast",
             "Computers are amazing",
             "MacBook is a great computer",
             "MacBook is expensive",
             "Best cat photo I've ever taken.",
             "google is amazing",
             "Impressed with google map feedback.",
             "Using google pixel is amazing",
             "Google takes too much data",
             "Google is bad",
             "MacBook is an Apple product"]

# Fit the vectorizer and materialize the sparse matrix as a dense array.
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(documents).toarray()
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gT2ZKPkkqLPW", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
# Random initialization: K distinct data points as means, the pooled data
# covariance (ridge-regularized so it is invertible) for every cluster,
# and uniform mixture weights.
K = 3  # number of clusters
chosen = np.random.choice(len(X), K, replace=False)
initial_means = [X[i] for i in chosen]
# X.shape[1] (the vocabulary size) replaces the hard-coded 35 so this cell
# keeps working when the document list changes; a comprehension gives each
# cluster its own covariance array instead of three references to one.
initial_covs = [np.cov(X, rowvar=0) + np.eye(X.shape[1]) * 1e-4 for _ in range(K)]

initial_weights = [1 / K] * K
# Use self defined EM method
results = EM(X, initial_means, initial_covs, initial_weights, maxiter=200)
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_65fzbuC6dVt", | |
"colab_type": "code", | |
"outputId": "1a761a8e-0f14-4a2f-85ea-783a8ec98b09", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 141 | |
} | |
}, | |
"source": [ | |
# Hard-assign each document to its most responsible cluster and print
# the resulting groups.
res = results
# np.argmax replaces the where/amax idiom; the list is sized from the data
# rather than a fixed 100-slot buffer, which would break (IndexError) once
# there are more than 100 documents.
cluster = [int(np.argmax(r)) for r in res]

# One bucket per cluster, derived from the responsibility-matrix width
# instead of a hard-coded 3.
groups = [[] for _ in range(res.shape[1])]
for doc, k in zip(documents, cluster):
    groups[k].append(doc)

for i, group in enumerate(groups):
    print("Group", i + 1)
    print(group)
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Group 1\n", | |
"['Epidemic spread dengue', 'malraia spread', 'Epidemic vaccine', 'Epidemic spread fast', \"Best cat photo I've ever taken.\"]\n", | |
"Group 2\n", | |
"['MacBook is a great computer', 'MacBook is expensive', 'Impressed with google map feedback.', 'Google takes too much data', 'MacBook is an Apple product']\n", | |
"Group 3\n", | |
"['Computers are amazing', 'google is amazing', 'Using google pixel is amazing', 'Google is bad']\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment