@xiaohan2012
Created February 1, 2014 19:40
A demo of Hidden Markov Model inference, covering parameter estimation, the forward-backward algorithm, and the Viterbi algorithm.
from __future__ import division
from collections import Counter, defaultdict
import operator
A = 'A'; H = 'H'
STATES = ['alpha', 'beta']; OBS = [A, H]
#given a list of training data, i.e. a list of (state sequence, observation sequence) pairs,
#derive the HMM model parameters
train_data = [(('alpha', 'beta', 'beta', 'alpha', 'alpha', 'beta', 'beta'), (A, A, A, H, H, H, A)),
              (('alpha', 'alpha', 'alpha', 'beta', 'alpha', 'alpha', 'beta'), (H, H, A, A, H, H, H)),
              (('beta', 'alpha', 'alpha', 'alpha', 'beta', 'beta'), (A, A, A, A, H, A)),
              (('alpha', 'beta', 'beta', 'alpha', 'alpha', 'alpha', 'beta', 'alpha'), (A, H, A, A, A, H, A, A))]
state_sequences = map (lambda row: row [0], train_data)
obs_sequences = map (lambda row: row [1], train_data)
#the state transition probabilities, estimated with add-one (Laplace) smoothing
state_pairs = [(states [j], states [j+1]) for states in state_sequences for j in xrange (len (states) - 1)]
state_pairs_freq = Counter (state_pairs)
#every state except the last in a sequence can start a transition
flat_states_sequences = [state for states in state_sequences for state in states [:-1]]
states_freq = Counter (flat_states_sequences)
state_tbl = dict ([((from_state, to_state),
                    (state_pairs_freq [(from_state, to_state)] + 1) / (states_freq [from_state] + len (STATES)))
                   for from_state in STATES
                   for to_state in STATES])
print 'state transition probability'
print state_tbl
print
#the emission probabilities, with the same add-one smoothing
state_obs_pairs = [pair for row in train_data for pair in zip (*row)]
state_obs_freq = Counter (state_obs_pairs)
flat_states_sequences = [state for states in state_sequences for state in states]
states_freq = Counter (flat_states_sequences)
emission_tbl = dict ([((state, obs), (state_obs_freq [(state, obs)] + 1) / (states_freq [state] + len (OBS)))
                      for state in STATES
                      for obs in OBS])
print 'emission probability'
print emission_tbl
print
#the prior probs
starting_states = map (lambda r: r [0] [0], train_data); starting_states_freq = Counter (starting_states)
prior_tbl = dict ( [(s, (starting_states_freq [s] + 1) / (len (starting_states) + len (STATES)) ) for s in STATES] )
print 'prior probability'
print prior_tbl
print
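# Sanity check (added, not part of the original gist): with add-one smoothing,
# each learned conditional distribution should still sum to 1.
for s in STATES:
    assert abs (sum (state_tbl [(s, to)] for to in STATES) - 1.0) < 1e-9
    assert abs (sum (emission_tbl [(s, ob)] for ob in OBS) - 1.0) < 1e-9
assert abs (sum (prior_tbl.values ()) - 1.0) < 1e-9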
hmm = {
    'state_tbl': state_tbl,
    'emission_tbl': emission_tbl,
    'prior_tbl': prior_tbl
}
#the forward-backward algorithm to compute the posterior probabilities
def fb (obs, hmm):
    """
    The forward-backward algorithm.
    Input: the observations and the HMM model parameters
    Output: the posterior distributions P(S_k | O_1, ..., O_T)
    """
    state_tbl = hmm ['state_tbl']
    emission_tbl = hmm ['emission_tbl']
    prior_tbl = hmm ['prior_tbl']
    #forward pass: ftbl[t][s] = P(O_1, ..., O_t, S_t = s)
    ftbl = defaultdict (dict)
    for t in xrange (len (obs)):
        ob = obs [t]
        for s in STATES:
            if t == 0:
                ftbl [t] [s] = prior_tbl [s] * emission_tbl [(s, ob)]
            else:
                ftbl [t] [s] = emission_tbl [(s, ob)] * sum (ftbl [t-1] [ps] * state_tbl [(ps, s)] for ps in STATES)
    #backward pass: btbl[t][s] = P(O_{t+1}, ..., O_T | S_t = s)
    btbl = defaultdict (dict)
    for t in xrange (len (obs) - 1, -1, -1):
        for s in STATES:
            if t == len (obs) - 1:
                btbl [t] [s] = 1
            else:
                #note: the recursion uses the emission at time t+1, not t
                btbl [t] [s] = sum (emission_tbl [(ns, obs [t+1])] * btbl [t+1] [ns] * state_tbl [(s, ns)] for ns in STATES)
    #combine: P(S_k | O_1, ..., O_T) is proportional to ftbl[k][s] * btbl[k][s]
    posterior = defaultdict (dict)
    for k in xrange (len (obs)):
        for s in STATES:
            posterior [k] [s] = ftbl [k] [s] * btbl [k] [s]
        #normalize over states
        Z = sum (posterior [k].values ())
        for s in STATES:
            posterior [k] [s] /= Z
    return posterior
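# Note (added): the unscaled forward/backward products above underflow for long
# observation sequences. A standard remedy is to rescale each column or to work
# in log space. A minimal log-space forward pass might look like the sketch
# below (it assumes the same hmm dict layout; not part of the original gist).
from math import log, exp

def log_sum_exp (xs):
    m = max (xs)
    return m + log (sum (exp (x - m) for x in xs))

def log_forward (obs, hmm):
    a, b, pi = hmm ['state_tbl'], hmm ['emission_tbl'], hmm ['prior_tbl']
    lf = dict ((s, log (pi [s]) + log (b [(s, obs [0])])) for s in STATES)
    for ob in obs [1:]:
        lf = dict ((s, log (b [(s, ob)]) + log_sum_exp ([lf [ps] + log (a [(ps, s)]) for ps in STATES]))
                   for s in STATES)
    return lf  # lf[s] = log P(O_1, ..., O_T, S_T = s)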
def print_table (tbl, rows, cols):
    """
    Util function: print out a probability distribution table.
    """
    print '\t'.join ([''] + cols) + '\n'
    print '\n'.join (['\t'.join (['%d' % r] + map (lambda val: '%.3f' % val, [tbl [r] [c] for c in cols])) for r in rows])
obs1, obs2 = [A,H,H,A,A], [H,A,A,H,A]
pos1 = fb (obs1, hmm)
print 'For observation ', ''.join(obs1), 'the posterior probability table is:'
print_table (pos1, xrange (len (obs1)), STATES)
print '\n\n'
pos2 = fb (obs2, hmm)
print 'For observation ', ''.join(obs2), 'the posterior probability table is:'
print_table (pos2, xrange (len (obs2)), STATES)
print '\n\n'
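# Sanity check (added): each posterior row should be a proper distribution over states.
for k in xrange (len (obs1)):
    assert abs (sum (pos1 [k].values ()) - 1.0) < 1e-9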
#the Viterbi algorithm
def viterbi (obs, hmm):
    """
    The Viterbi algorithm: find the most probable state sequence.
    """
    a, b, pi = hmm ['state_tbl'], hmm ['emission_tbl'], hmm ['prior_tbl']
    delta = defaultdict (dict)  #delta[t][s]: max probability of any path ending in state s at time t
    bp = defaultdict (dict)     #bp[t][s]: the best predecessor of state s at time t
    for t in xrange (len (obs)):
        ob = obs [t]
        for s in STATES:
            if t == 0:
                delta [t] [s] = pi [s] * b [(s, ob)]
            else:
                bp [t] [s], delta [t] [s] = max ([(ps, delta [t-1] [ps] * a [(ps, s)] * b [(s, ob)]) for ps in STATES],
                                                 key = operator.itemgetter (1))
    #get the most probable last state
    state, _ = max (delta [len (obs) - 1].items (), key = operator.itemgetter (1))
    t = len (obs) - 1
    state_sequence = [state]
    #back tracing through the stored back pointers
    while t in bp and state in bp [t]:
        state = bp [t] [state]
        t -= 1
        state_sequence.append (state)
    return delta, state_sequence [::-1]
_, seq1 = viterbi (list ('AHHAA'), hmm)
print 'for observations AHHAA, the most probable state sequence is: ' + ' '.join (seq1)
_, seq2 = viterbi (list ('HAAHA'), hmm)
print 'for observations HAAHA, the most probable state sequence is: ' + ' '.join (seq2)
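# Brute-force check (added, not in the original gist): for short sequences we
# can enumerate every state path and confirm that Viterbi finds a maximizing one.
from itertools import product

def joint_prob (states, obs, hmm):
    a, b, pi = hmm ['state_tbl'], hmm ['emission_tbl'], hmm ['prior_tbl']
    p = pi [states [0]] * b [(states [0], obs [0])]
    for t in xrange (1, len (obs)):
        p *= a [(states [t-1], states [t])] * b [(states [t], obs [t])]
    return p

obs = list ('AHHAA')
best = max (product (STATES, repeat = len (obs)), key = lambda ss: joint_prob (ss, obs, hmm))
assert abs (joint_prob (seq1, obs, hmm) - joint_prob (best, obs, hmm)) < 1e-12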