This Gist was automatically created by Carbide, a free online programming environment.
Last active: September 17, 2017
Expectation Maximization
/// This notebook is based on the tutorial paper _"What is the expectation maximization algorithm?"_ by Chuong B. Do & Serafim Batzoglou (http://ai.stanford.edu/~chuongdo/papers/em_tutorial.pdf).
/// Consider the following scenario: we have two coins _A_ and _B_, and we want to learn their biases **ϴ** = [ ϴ_A, ϴ_B ].
///
/// A trial proceeds as follows:
/// (1) Select one of the coins uniformly at random.
/// (2) Toss the coin ten times, recording each outcome.
///
/// A trial can be denoted as a pair _(z, x)_, where:
/// - _z_ ∈ { _A_, _B_ } denotes the coin selected in step 1, and
/// - _x_ ∈ { 0, 1, ..., 10 } is the sufficient statistic denoting the number of heads observed among the ten tosses in step 2.
///
/// Accordingly, we can represent a series of _N_ trials as two vectors:
/// - _**Z**_ = [ _z1_, _z2_, ..., _zN_ ], the sequence of coins selected, and
/// - _**X**_ = [ _x1_, _x2_, ..., _xN_ ], the sequence of observed head counts across the _N_ trials.
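The generative process in steps (1)–(2) can be sketched directly. Note the `theta` values below are hypothetical "true" biases chosen for illustration; they are not part of the notebook's data:

```javascript
// Sketch of the generative process: pick a coin, toss it ten times.
// These true biases are illustrative placeholders, not notebook values.
const theta = { A: 0.8, B: 0.45 };

function runTrial() {
  // (1) Select one of the coins uniformly at random.
  const z = Math.random() < 0.5 ? 'A' : 'B';
  // (2) Toss the selected coin ten times, recording each outcome.
  const tosses = Array.from({ length: 10 }, () =>
    Math.random() < theta[z] ? 1 : 0);
  // x is the sufficient statistic: the number of heads observed.
  const x = tosses.reduce((acc, t) => acc + t, 0);
  return { z, x };
}

// N trials yield the vectors Z and X described above.
const trials = Array.from({ length: 5 }, runTrial);
const Z = trials.map(t => t.z);
const X = trials.map(t => t.x);
```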
const data = [
  { coin: 'B', tosses: [1, 0, 0, 0, 1, 1, 0, 1, 0, 1] },
  { coin: 'A', tosses: [1, 1, 1, 1, 0, 1, 1, 1, 1, 1] },
  { coin: 'A', tosses: [1, 0, 1, 1, 1, 1, 1, 0, 1, 1] },
  { coin: 'B', tosses: [1, 0, 1, 0, 0, 0, 1, 1, 0, 0] },
  { coin: 'A', tosses: [0, 1, 1, 1, 0, 1, 1, 1, 0, 1] },
];

// GridArrayWidget
export const dataRows = data.map((trial) => {
  const { coin, tosses } = trial;
  const heads = tosses.filter(n => n === 1).length;
  const tails = tosses.length - heads;
  const ht = `${heads} H, ${tails} T`;
  const coinA = coin === 'A' ? ht : '';
  const coinB = coin === 'B' ? ht : '';
  return { coin, heads, tails, coinA, coinB };
});

// GridArrayWidget
export const totals = ['A', 'B']
  .map(c => dataRows.filter(r => r.coin === c))
  .map(rows => ({
    coin: rows[0].coin,
    heads: rows.reduce((acc, curr) => acc + curr.heads, 0),
    tails: rows.reduce((acc, curr) => acc + curr.tails, 0),
  }));
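As a hand-check on the pipeline above, the per-coin totals can be recomputed directly from the raw tosses. The data is restated here so the snippet runs on its own:

```javascript
// Same five trials as above, inlined so this snippet is self-contained.
const data = [
  { coin: 'B', tosses: [1, 0, 0, 0, 1, 1, 0, 1, 0, 1] },
  { coin: 'A', tosses: [1, 1, 1, 1, 0, 1, 1, 1, 1, 1] },
  { coin: 'A', tosses: [1, 0, 1, 1, 1, 1, 1, 0, 1, 1] },
  { coin: 'B', tosses: [1, 0, 1, 0, 0, 0, 1, 1, 0, 0] },
  { coin: 'A', tosses: [0, 1, 1, 1, 0, 1, 1, 1, 0, 1] },
];

// Tally heads and tails per coin in a single pass.
const byCoin = { A: { heads: 0, tails: 0 }, B: { heads: 0, tails: 0 } };
for (const { coin, tosses } of data) {
  for (const t of tosses) {
    if (t === 1) byCoin[coin].heads += 1;
    else byCoin[coin].tails += 1;
  }
}
// byCoin → { A: { heads: 24, tails: 6 }, B: { heads: 9, tails: 11 } }
```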
import { X, numTosses } from './missingData';

/// The latent variable _**Z**_ has two possible values, A or B.
const Z = ['A', 'B'];

// 1. Start from prior guesses for theta
const thetaPrior = { A: 0.6, B: 0.5 };

/// For each trial _x_ in _**X**_, we compute the expectation of _z_ using our current estimates for ϴ. Recall that the expectation simply weights each possible value { _A_, _B_ } according to its probability.
function EM(iterations, X, thetaPrior) {
  let theta = thetaPrior;
  for (let i = 0; i < iterations; i += 1) {
    // (One could also track the log likelihood here to monitor convergence.)
    // 2. E-step: accumulate the expected numbers of heads and tails
    // across all data points, for each value of Z
    const expectedHeads = { A: 0, B: 0 };
    const expectedTails = { A: 0, B: 0 };
    for (const x of X) {
      let sum = 0;
      const pZGivenX = {};
      // 2.1 First get the distribution P(z | x), up to a constant:
      // the binomial coefficient cancels when we normalize
      for (const z of Z) {
        pZGivenX[z] = Math.pow(theta[z], x) * Math.pow(1 - theta[z], numTosses - x);
        // Increment the sum for normalizing
        sum += pZGivenX[z];
      }
      // Normalize P(z | x)
      for (const z of Z) {
        pZGivenX[z] = pZGivenX[z] / sum;
      }
      // Weight our observations by the probability of choosing each coin,
      // i.e. add this example's contribution to the expected statistics
      for (const z of Z) {
        expectedHeads[z] += pZGivenX[z] * x;
        expectedTails[z] += pZGivenX[z] * (numTosses - x);
      }
    }
    // 3. M-step: re-estimate theta from the expected counts
    const nextTheta = {};
    for (const z of Z) {
      nextTheta[z] = expectedHeads[z] / (expectedHeads[z] + expectedTails[z]);
    }
    theta = nextTheta;
  }
  return theta;
}

EM(10, X, thetaPrior);
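As a self-contained check (data inlined so it runs on its own), the full EM loop — E-step followed by an M-step that re-estimates ϴ from the expected head and tail counts — started from ϴ = [0.6, 0.5] approaches the θ̂_A ≈ 0.80, θ̂_B ≈ 0.52 reported in the Do & Batzoglou tutorial:

```javascript
// Head counts for the five trials above, 10 tosses each.
const X = [5, 9, 8, 4, 7];
const numTosses = 10;
const Z = ['A', 'B'];

function EM(iterations, X, thetaPrior) {
  let theta = { ...thetaPrior };
  for (let i = 0; i < iterations; i += 1) {
    const heads = { A: 0, B: 0 };
    const tails = { A: 0, B: 0 };
    for (const x of X) {
      // E-step: P(z | x); the binomial coefficient cancels on normalizing.
      const pZGivenX = {};
      let sum = 0;
      for (const z of Z) {
        pZGivenX[z] = Math.pow(theta[z], x) * Math.pow(1 - theta[z], numTosses - x);
        sum += pZGivenX[z];
      }
      for (const z of Z) {
        pZGivenX[z] /= sum;
        heads[z] += pZGivenX[z] * x;
        tails[z] += pZGivenX[z] * (numTosses - x);
      }
    }
    // M-step: re-estimate theta from the expected counts.
    for (const z of Z) {
      theta[z] = heads[z] / (heads[z] + tails[z]);
    }
  }
  return theta;
}

const theta = EM(20, X, { A: 0.6, B: 0.5 });
// theta approaches { A: ≈0.80, B: ≈0.52 }
```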
/// Now, suppose we are given an incomplete dataset; namely, we are missing information about _**z**_.
const tosses = [
  [1, 0, 0, 0, 1, 1, 0, 1, 0, 1],
  [1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
  [1, 0, 1, 1, 1, 1, 1, 0, 1, 1],
  [1, 0, 1, 0, 0, 0, 1, 1, 0, 0],
  [0, 1, 1, 1, 0, 1, 1, 1, 0, 1],
];

export const numTosses = tosses[0].length;

/// Get the sufficient statistic (number of heads observed) for each trial.
const countOnes = (acc, curr) => curr === 1 ? acc + 1 : acc;
export const X = tosses.map(trial => trial.reduce(countOnes, 0));
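Evaluating that mapping on the five trials (restated here so the snippet runs standalone) gives the head counts the EM loop consumes:

```javascript
// Same toss records as above, inlined for a standalone check.
const tosses = [
  [1, 0, 0, 0, 1, 1, 0, 1, 0, 1],
  [1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
  [1, 0, 1, 1, 1, 1, 1, 0, 1, 1],
  [1, 0, 1, 0, 0, 0, 1, 1, 0, 0],
  [0, 1, 1, 1, 0, 1, 1, 1, 0, 1],
];
const countOnes = (acc, curr) => curr === 1 ? acc + 1 : acc;
const X = tosses.map(trial => trial.reduce(countOnes, 0));
// X → [5, 9, 8, 4, 7]
```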
/// Given a complete dataset, we can easily compute the MLE for **ϴ** = [ ϴ_A, ϴ_B ] using the formula: MLE = h / (h + t)
import { totals } from './data';

const MLE = (h, t) => h / (h + t);

// GridArrayWidget
const mleTable = totals.map(record => ({
  coin: record.coin,
  heads: record.heads,
  total: record.heads + record.tails,
  maxLikelihood: MLE(record.heads, record.tails),
}));
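Plugging in the totals from the complete dataset (coin A: 24 heads, 6 tails; coin B: 9 heads, 11 tails) gives the fully-observed maximum-likelihood estimates:

```javascript
const MLE = (h, t) => h / (h + t);

// Totals tallied from the five fully-observed trials above.
const thetaA = MLE(24, 6);  // 24 / 30 = 0.8
const thetaB = MLE(9, 11);  // 9 / 20 = 0.45
```

These are the values EM tries to recover when the coin labels _z_ are hidden.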