EM Algo.ipynb (gist by @Janglee123, created May 25, 2020)
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "EM Algo.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOm4E7DK4w7qh41uyZV8gE6",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Janglee123/1f5f6cd3eee69f99dd4c8465445887e2/em-algo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "QnRLy86tlM7E",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt \n",
"import copy\n",
"from scipy.stats import multivariate_normal"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KWYWOMsMdY8b",
"colab_type": "code",
"colab": {}
},
"source": [
"def log_sum_exp(Z):\n",
" return np.max(Z) + np.log(np.sum(np.exp(Z - np.max(Z))))"
],
"execution_count": 0,
"outputs": []
},
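{
"cell_type": "markdown",
"metadata": {},
"source": [
"`log_sum_exp` computes $\\log \\sum_k e^{Z_k}$ stably: subtracting $\\max(Z)$ keeps every exponent at or below zero so nothing overflows, and the maximum is added back outside the log. A quick sanity check (editor's sketch, not part of the original gist):\n",
"\n",
"```python\n",
"Z = np.array([-1000.0, -1001.0])\n",
"np.log(np.sum(np.exp(Z)))  # -inf, because np.exp underflows to 0\n",
"log_sum_exp(Z)             # about -999.69, the correct value\n",
"```"
]
},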
{
"cell_type": "code",
"metadata": {
"id": "GYtMreqYdcLm",
"colab_type": "code",
"colab": {}
},
"source": [
"def loglikelihood(data, weights, means, covs):\n",
" num_clusters = len(means)\n",
" num_dim = len(data[0])\n",
" \n",
" ll = 0\n",
" for d in data:\n",
" \n",
" Z = np.zeros(num_clusters)\n",
" for k in range(num_clusters):\n",
" \n",
" delta = np.array(d) - means[k]\n",
" # covs[k] = covs[k] + 1 \n",
" exponent_term = np.dot(delta.T, np.dot(np.linalg.pinv(covs[k]), delta))\n",
" \n",
" Z[k] += np.log(weights[k])\n",
" Z[k] -= 1/2. * (num_dim * np.log(2*np.pi) + np.log(np.linalg.det(covs[k])) + exponent_term)\n",
" \n",
" ll += log_sum_exp(Z) \n",
" return ll"
],
"execution_count": 0,
"outputs": []
},
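{
"cell_type": "markdown",
"metadata": {},
"source": [
"`loglikelihood` evaluates the Gaussian-mixture log-likelihood\n",
"\n",
"$$\\ell = \\sum_{i=1}^{N} \\log \\sum_{k=1}^{K} \\pi_k \\, \\mathcal{N}(x_i \\mid \\mu_k, \\Sigma_k),$$\n",
"\n",
"where each `Z[k]` holds $\\log \\pi_k + \\log \\mathcal{N}(x_i \\mid \\mu_k, \\Sigma_k)$ and the inner sum over clusters is taken with `log_sum_exp` for numerical stability."
]
},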
{
"cell_type": "code",
"metadata": {
"id": "7lZrctwadf9F",
"colab_type": "code",
"colab": {}
},
"source": [
"def EM(data, init_means, init_covariances, init_weights, maxiter=1000, thresh=1e-4):\n",
" \n",
" means = init_means[:]\n",
" covariances = init_covariances[:]\n",
" weights = init_weights[:]\n",
" \n",
" num_data = len(data)\n",
" num_dim = len(data[0])\n",
" num_clusters = len(means)\n",
" \n",
" resp = np.zeros((num_data, num_clusters))\n",
" ll = loglikelihood(data, weights, means, covariances)\n",
" ll_trace = [ll]\n",
" \n",
" for i in range(maxiter):\n",
" \n",
" # E-step: compute responsibilities\n",
" for j in range(num_data):\n",
" for k in range(num_clusters):\n",
" resp[j, k] = weights[k]*multivariate_normal.pdf(data[j],means[k],covariances[k], allow_singular=True)\n",
" row_sums = resp.sum(axis=1)[:, np.newaxis]\n",
" resp = resp / row_sums # normalize over all possible cluster assignments\n",
"\n",
" # M-step\n",
" counts = np.sum(resp, axis=0)\n",
" \n",
" for k in range(num_clusters):\n",
" \n",
" weights[k] = counts[k]/num_data\n",
" weighted_sum = 0\n",
" for j in range(num_data):\n",
" weighted_sum += (resp[j,k]*data[j])\n",
" means[k] = weighted_sum/counts[k]\n",
" \n",
" weighted_sum = np.zeros((num_dim, num_dim))\n",
" for j in range(num_data):\n",
" weighted_sum += (resp[j,k]*np.outer(data[j]-means[k],data[j]-means[k]))\n",
" covariances[k] = weighted_sum/counts[k]\n",
" \n",
" ll_latest = loglikelihood(data, weights, means, covariances)\n",
" ll_trace.append(ll_latest)\n",
" \n",
" if (ll_latest - ll) < thresh and ll_latest > -np.inf:\n",
" break\n",
" ll = ll_latest\n",
"\n",
" return resp\n"
],
"execution_count": 0,
"outputs": []
},
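{
"cell_type": "markdown",
"metadata": {},
"source": [
"`EM` alternates the standard Gaussian-mixture updates: the E-step sets $r_{jk} \\propto \\pi_k \\, \\mathcal{N}(x_j \\mid \\mu_k, \\Sigma_k)$ and the M-step re-estimates $\\pi_k, \\mu_k, \\Sigma_k$ as responsibility-weighted averages. As a cross-check on this hand-rolled version (editor's sketch, not part of the original gist), scikit-learn's `GaussianMixture` runs the same algorithm:\n",
"\n",
"```python\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"gm = GaussianMixture(n_components=3, covariance_type='full', reg_covar=1e-4)\n",
"labels = gm.fit_predict(X)   # hard cluster assignments\n",
"soft = gm.predict_proba(X)   # responsibilities, analogous to resp above\n",
"```"
]
},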
{
"cell_type": "code",
"metadata": {
"id": "BaPDeTN6kH9j",
"colab_type": "code",
"colab": {}
},
"source": [
"# Use sklearn to create embedings from sentence \n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"documents = [\"Epidemic spread dengue\",\n",
" \"malraia spread\",\n",
" \"Epidemic vaccine\",\n",
" \"Epidemic spread fast\",\n",
" \"Computers are amazing\",\n",
" \"MacBook is a great computer\",\n",
" \"MacBook is expensive\",\n",
" \"Best cat photo I've ever taken.\",\n",
" \"google is amazing\",\n",
" \"Impressed with google map feedback.\",\n",
" \"Using google pixel is amazing\",\n",
" \"Google takes too much data\",\n",
" \"Google is bad\",\n",
" \"MacBook is an Apple product\"]\n",
"\n",
"vectorizer = TfidfVectorizer()\n",
"X = vectorizer.fit_transform(documents)\n",
"\n",
"X = X.toarray()\n"
],
"execution_count": 0,
"outputs": []
},
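{
"cell_type": "markdown",
"metadata": {},
"source": [
"`TfidfVectorizer` maps each sentence to a sparse vector of TF-IDF weights over the corpus vocabulary, and `toarray()` densifies it for the EM code above. Inspecting the result (editor's sketch, not part of the original gist):\n",
"\n",
"```python\n",
"print(X.shape)  # (14, 35): 14 documents, 35 vocabulary terms\n",
"```"
]
},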
{
"cell_type": "code",
"metadata": {
"id": "gT2ZKPkkqLPW",
"colab_type": "code",
"colab": {}
},
"source": [
"chosen = np.random.choice(len(X), 3, replace=False)\n",
"initial_means = [X[x] for x in chosen]\n",
"initial_covs = [np.cov(X, rowvar=0) + np.eye(35) * 1e-4] * 3 \n",
"\n",
"initial_weights = [1/3.] * 3\n",
"# Use self defined EM method\n",
"results = EM(X, initial_means, initial_covs, initial_weights, maxiter=200)"
],
"execution_count": 0,
"outputs": []
},
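{
"cell_type": "markdown",
"metadata": {},
"source": [
"`results` is the final responsibility matrix: row $j$ gives the posterior probability of each cluster for document $j$, which the next cell hardens with an argmax. Since the means are initialized from three randomly chosen documents, different runs can converge to different groupings; for reproducibility one could seed NumPy first (editor's note, not part of the original gist):\n",
"\n",
"```python\n",
"np.random.seed(0)  # fix the random draw of initial means\n",
"```"
]
},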
{
"cell_type": "code",
"metadata": {
"id": "_65fzbuC6dVt",
"colab_type": "code",
"outputId": "1a761a8e-0f14-4a2f-85ea-783a8ec98b09",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 141
}
},
"source": [
"res = results\n",
"culster = [-1]*100\n",
"for i in range(len(res)):\n",
" culster[i] = np.where(res[i] == np.amax(res[i]))[0][0]\n",
"\n",
"data_ = [[], [], []]\n",
"\n",
"for i in range(len(documents)):\n",
" data_[culster[i]].append(documents[i])\n",
"\n",
"for i in range(len(data_)):\n",
" print(\"Group\", i + 1)\n",
" print(data_[i])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Group 1\n",
"['Epidemic spread dengue', 'malraia spread', 'Epidemic vaccine', 'Epidemic spread fast', \"Best cat photo I've ever taken.\"]\n",
"Group 2\n",
"['MacBook is a great computer', 'MacBook is expensive', 'Impressed with google map feedback.', 'Google takes too much data', 'MacBook is an Apple product']\n",
"Group 3\n",
"['Computers are amazing', 'google is amazing', 'Using google pixel is amazing', 'Google is bad']\n"
],
"name": "stdout"
}
]
}
]
}