Last active
December 15, 2015 15:49
-
-
Save kmike/5285124 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

try:
    from scipy.special import logsumexp
except ImportError:  # very old SciPy only shipped it in scipy.misc
    from scipy.misc import logsumexp
def _forward(n_observations, n_components, log_startprob, log_transmat, | |
framelogprob): | |
fwdlattice = np.empty((n_observations, n_components)) | |
fwdlattice[0] = log_startprob + framelogprob[0] | |
for t in range(1, n_observations): | |
summand = fwdlattice[t-1] + log_transmat.T | |
l_sum = logsumexp(summand, axis=1) | |
# FIXME: how to vectorize this loop? | |
for j in range(n_components): | |
fwdlattice[t, j] = l_sum[j] + framelogprob[t, j] | |
return fwdlattice | |
def _backward(n_observations, n_components, log_transmat, framelogprob): | |
bwdlattice = np.empty((n_observations, n_components)) | |
bwdlattice[n_observations - 1] = 0.0 | |
for t in range(n_observations - 2, -1, -1): | |
summand = log_transmat + framelogprob[t+1] + bwdlattice[t+1] | |
bwdlattice[t] = logsumexp(summand, axis=1) | |
return bwdlattice | |
def _viterbi(n_observations, n_components, log_startprob, log_transmat, | |
framelogprob): | |
# Initialization | |
state_sequence = np.empty(n_observations, dtype=np.int) | |
viterbi_lattice = np.zeros((n_observations, n_components)) | |
viterbi_lattice[0] = log_startprob + framelogprob[0] | |
# Induction | |
for t in range(1, n_observations): | |
work_buffer = viterbi_lattice[t-1] + log_transmat.T | |
viterbi_lattice[t] = np.max(work_buffer, axis=1) + framelogprob[t] | |
# Observation traceback | |
max_pos = np.argmax(viterbi_lattice[n_observations - 1, :]) | |
state_sequence[n_observations - 1] = max_pos | |
logprob = viterbi_lattice[n_observations - 1, max_pos] | |
for t in range(n_observations - 2, -1, -1): | |
max_pos = np.argmax(viterbi_lattice[t, :] \ | |
+ log_transmat[:, state_sequence[t + 1]]) | |
state_sequence[t] = max_pos | |
return state_sequence, logprob |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.