Skip to content

Instantly share code, notes, and snippets.

@indiejoseph
Last active August 29, 2015 14:01
Show Gist options
  • Select an option

  • Save indiejoseph/f7dc6982f57ef7508db6 to your computer and use it in GitHub Desktop.

Select an option

Save indiejoseph/f7dc6982f57ef7508db6 to your computer and use it in GitHub Desktop.
Baum-Welch Algorithm
'use strict'
_ = require 'lodash'
# Stand-in for log(0): a huge but finite negative number, so sums of
# log-probabilities stay finite while Math.exp of it underflows to 0.
MIN_FLOAT = -3.14e+100
# Object::default(prop, value): set own property `prop` to `value` unless the
# object already has an own property with that name; returns the assigned
# value, or undefined when nothing was assigned.
# Installed with defineProperty and enumerable: false. A plain
# `Object::default = ...` assignment would be enumerable and would show up in
# every `for key of obj` loop over any object in the process.
Object.defineProperty Object.prototype, 'default',
  configurable: yes
  writable: yes
  enumerable: no
  value: (prop, value) ->
    # hasOwnProperty via call: safe even if the object shadows hasOwnProperty
    @[prop] = value unless Object::hasOwnProperty.call(this, prop)
# Object::getValue(prop, value): return the object's own property `prop`, or
# `value` when the object has no own property of that name (inherited
# properties are ignored).
# Installed with defineProperty and enumerable: false, so it does not leak
# into `for key of obj` loops the way a plain prototype assignment would.
Object.defineProperty Object.prototype, 'getValue',
  configurable: yes
  writable: yes
  enumerable: no
  value: (prop, value) ->
    if Object::hasOwnProperty.call(this, prop) then @[prop] else value
# Math.sum(arr): total of the elements of `arr`, folded with lodash's reduce
# starting from 0.
Math.sum = (arr) ->
  add = (total, item) -> total + item
  _.reduce arr, add, 0
# Natural logarithm that returns the MIN_FLOAT sentinel instead of -Infinity
# or NaN whenever `value` is not strictly positive (including NaN).
log = (value) ->
  return MIN_FLOAT unless value > 0
  Math.log value
class BaumWelch
  # ---- private helpers (class-body locals; not visible outside the class) ----

  # Own-property lookup with a fallback. This keeps the class independent of
  # the Object::getValue prototype extension and of inherited keys.
  lookup = (table, key, fallback) ->
    if table? and Object::hasOwnProperty.call(table, key) then table[key] else fallback

  # Number of entries in an array, or in a plain object keyed by index.
  sizeOf = (coll) ->
    if Array.isArray(coll) then coll.length else Object.keys(coll).length

  # Numerically stable log(sum(exp(v) for v in values)), floored at MIN_FLOAT
  # (this file's stand-in for log(0)). Summing Math.exp(...) of log-probabilities
  # directly underflows to 0 on long sequences; shifting by the max does not.
  logSumExp = (values) ->
    return MIN_FLOAT if values.length is 0
    top = Math.max values...
    return MIN_FLOAT if isNaN(top) or top <= MIN_FLOAT
    total = 0
    total += Math.exp(v - top) for v in values
    Math.max MIN_FLOAT, top + Math.log(total)

  # Normalise a log score by a log partition value; "zero" maps to MIN_FLOAT.
  normalise = (score, logNorm) ->
    if logNorm <= MIN_FLOAT then MIN_FLOAT else Math.max(MIN_FLOAT, score - logNorm)

  # Backward initial values at the last position: beta(s) = 1, i.e. log 0... no:
  # log(1) = 0 for every state.
  terminalColumn = (states) ->
    column = {}
    column[s] = 0 for s in states
    column

  # One forward step: log alpha for observation `obs`, given the previous column
  # `prev` (null at the first position, which uses the start probabilities).
  forwardStep = (model, prev, obs, states) ->
    column = {}
    for s in states
      emitP = lookup model.probEmit[s], obs, MIN_FLOAT
      if prev?
        terms = (prev[s0] + lookup(model.probTrans[s0], s, MIN_FLOAT) + emitP for s0 in states)
        column[s] = logSumExp terms
      else
        column[s] = lookup(model.probStart, s, MIN_FLOAT) + emitP
    column

  # One backward step: log beta at position t, computed from column t+1
  # (`next`) and the observation at t+1 (`nextObs`).
  backwardStep = (model, next, nextObs, states) ->
    emitNext = {}
    emitNext[s1] = lookup(model.probEmit[s1], nextObs, MIN_FLOAT) for s1 in states
    column = {}
    for s in states
      terms = (next[s1] + lookup(model.probTrans[s], s1, MIN_FLOAT) + emitNext[s1] for s1 in states)
      column[s] = logSumExp terms
    column

  # Log alpha at position t of `seq`. Reads @forwardNet when `seq`/`states` are
  # the ones the net was built for; otherwise recomputes it.
  alphaAt = (model, seq, t, states) ->
    if seq is model.seq and states is model.states and model.forwardNet?[t]?
      model.forwardNet[t]
    else
      model.forward seq.slice(0, t + 1), states

  # Log beta at position t of `seq`. Reads @backwardNet when possible, as above.
  betaAt = (model, seq, t, states) ->
    if seq is model.seq and states is model.states and model.backwardNet?[t]?
      model.backwardNet[t]
    else
      model.backward seq.slice(t), states

  # Keep the [score, state] candidate with the larger score (the first one wins ties).
  maxFirst = (a, b) ->
    if a[0] is Math.max(a[0], b[0]) then a else b

  # ---- public API ----

  # Parameters:
  # - observation (output) sequence and hidden-state sequence;
  # - list of states;
  # - char -> candidate-states map;
  # - initial start / transition / emission tables, all log-probabilities.
  # The tables are referenced, not copied, so the constructor and doEM update
  # them in place.
  constructor: (outSeq, hiddenSeq, @states, @charStates={}, @probStart, @probTrans, @probEmit, @maxIterNum=20) ->
    # list of [observation, hiddenState] pairs
    @seq = _.zip outSeq, hiddenSeq
    @forwardNet = {}
    @backwardNet = {}
    for s in @states
      @probTrans[s] = {} unless @probTrans[s]
      @probEmit[s] = {} unless @probEmit[s]
    # record which hidden states each observed char was seen with
    for c, i in outSeq
      s = hiddenSeq[i]
      @charStates[c] = [] unless Object::hasOwnProperty.call(@charStates, c)
      @charStates[c].push s unless s in @charStates[c]

  # Fill @forwardNet[t] and @backwardNet[t] (log alpha / log beta keyed by
  # state) for every position of @seq. A net is only rebuilt when it is empty,
  # so callers that change the parameters must reset it (doEM does).
  buildnet: ->
    T = @seq.length
    if not @forwardNet? or Object.keys(@forwardNet).length is 0
      @forwardNet = {}
      prev = null
      for t in [0...T] by 1
        prev = @forwardNet[t] = forwardStep(this, prev, @seq[t][0], @states)
    if not @backwardNet? or Object.keys(@backwardNet).length is 0
      @backwardNet = {}
      return if T is 0
      # backward initial values: T - 1 is the last position of the sequence
      @backwardNet[T - 1] = terminalColumn @states
      for t in [T - 2..0] by -1
        @backwardNet[t] = backwardStep(this, @backwardNet[t + 1], @seq[t + 1][0], @states)
    return

  # Forward (log alpha) values at the LAST position of `seq`, keyed by state.
  # `seq` is a list of [observation, hiddenState] pairs; returns {} when empty.
  forward: (seq, states) ->
    column = null
    column = forwardStep(this, column, pair[0], states) for pair in seq
    column ? {}

  # Backward (log beta) values at the FIRST position of `seq`, keyed by state.
  # Returns {} for an empty `seq`.
  backward: (seq, states) ->
    return {} if seq.length is 0
    column = terminalColumn states
    for i in [seq.length - 2..0] by -1
      column = backwardStep(this, column, seq[i + 1][0], states)
    column

  # "r" (gamma) probability: the forward x backward product at position t,
  # normalised over states and returned in log space, keyed by state.
  getr: (seq, t, states) ->
    alpha = alphaAt this, seq, t, states
    beta = betaAt this, seq, t, states
    joint = {}
    joint[s] = alpha[s] + beta[s] for s in states
    scores = (joint[s] for s in states)
    logNorm = logSumExp scores
    gamma = {}
    gamma[s] = normalise(joint[s], logNorm) for s in states
    gamma

  # xi probability at position t: P(state s at t, state s1 at t+1 | seq), in
  # log space. Keyed by the stringified pair [s, s1]. Returns null at the last
  # position, where no transition follows.
  getxi: (seq, t, states) ->
    return null if t >= seq.length - 1
    alpha = alphaAt this, seq, t, states
    # xi needs beta at t+1, not at t
    betaNext = betaAt this, seq, t + 1, states
    nextObs = seq[t + 1][0]
    xi = {}
    scores = []
    for s in states
      for s1 in states
        score = alpha[s] + lookup(@probTrans[s], s1, MIN_FLOAT) + lookup(@probEmit[s1], nextObs, MIN_FLOAT) + betaNext[s1]
        xi[[s, s1]] = score
        scores.push score
    logNorm = logSumExp scores
    for s in states
      for s1 in states
        key = [s, s1]
        xi[key] = normalise(xi[key], logNorm)
    xi

  # Run up to @maxIterNum EM iterations over @seq, decrementing @maxIterNum.
  # Each iteration updates @probStart, @probTrans and @probEmit in place, and
  # the loop stops early once ErrorIsOk reports that the Viterbi path matches
  # the gold hidden states. `self` is unused and kept for call compatibility.
  doEM: (self) ->
    return if @seq.length is 0
    while @maxIterNum > 0
      # E step: the parameters changed last iteration, so rebuild alpha/beta
      @forwardNet = {}
      @backwardNet = {}
      @buildnet()
      seqR = (@getr(@seq, t, @states) for t in [0...@seq.length] by 1)
      seqXi = (@getxi(@seq, t, @states) for t in [0...@seq.length] by 1)
      # M step
      @updatePI seqR, @states
      @updateTrans seqR, seqXi, @states
      @updateEmit @seq, seqR, @states
      break if @ErrorIsOk @seq
      @maxIterNum--
    return

  # Update the start probabilities from the gamma at the first position.
  updatePI: (seqR, states) ->
    return unless seqR[0]?
    @probStart[s] = seqR[0][s] for s in states
    return

  # Update the transition probabilities:
  #   a(s, s1) = sum_t xi_t(s, s1) / sum_t gamma_t(s),  t = 0 .. T-2
  # Rows of states with no gamma mass are left unchanged. With fewer than two
  # positions there is nothing to re-estimate.
  updateTrans: (seqR, seqXi, states) ->
    steps = sizeOf(seqR) - 1
    return if steps <= 0
    for s in states
      sumR = 0
      sumR += Math.exp(seqR[t][s]) for t in [0...steps] by 1
      continue if sumR is 0
      @probTrans[s] ?= {}
      for s1 in states
        key = [s, s1]
        sumXi = 0
        sumXi += Math.exp(seqXi[t][key]) for t in [0...steps] by 1
        @probTrans[s][s1] = log(sumXi / sumR)
    return

  # Update the emission probabilities of every observation seen in `seq`:
  #   b(s, o) = sum_{t : o_t = o} gamma_t(s) / sum_t gamma_t(s)
  # Observations absent from `seq` keep their previous values; states with no
  # gamma mass are left unchanged.
  updateEmit: (seq, seqR, states) ->
    steps = Math.min seq.length, sizeOf(seqR)
    for s in states
      sumR = 0
      stateOutput = {}
      for t in [0...steps] by 1
        weight = Math.exp seqR[t][s]
        sumR += weight
        obs = seq[t][0]
        stateOutput[obs] = lookup(stateOutput, obs, 0) + weight
      continue if sumR is 0
      @probEmit[s] ?= {}
      @probEmit[s][o] = log(mass / sumR) for own o, mass of stateOutput
    return

  # Convergence check: true when the Viterbi path under the current parameters
  # reproduces the hidden state of every position of `seq`.
  ErrorIsOk: (seq) ->
    path = @viterbi((pair[0] for pair in seq), @states, @probStart, @probTrans, @probEmit)[1]
    return no unless path.length is seq.length
    for pair, i in seq
      return no if pair[1] isnt path[i]
    yes

  # Log-space Viterbi decoding. Returns [bestLogProb, bestStatePath].
  # An empty observation list yields [0, []].
  viterbi: (obs, states, startP, transP, emitP) ->
    return [0, []] if obs.length is 0
    V = [{}]
    path = {}
    for y in states
      V[0][y] = lookup(startP, y, MIN_FLOAT) + lookup(emitP[y], obs[0], MIN_FLOAT)
      path[y] = [y]
    for t in [1...obs.length] by 1
      V.push {}
      newpath = {}
      for y in states
        emP = lookup emitP[y], obs[t], MIN_FLOAT
        candidates = ([V[t - 1][y0] + lookup(transP[y0], y, MIN_FLOAT) + emP, y0] for y0 in states)
        [prob, state] = candidates.reduce maxFirst
        # store the best score; the next step reads it
        V[t][y] = prob
        newpath[y] = path[state].concat [y]
      path = newpath
    finals = ([V[obs.length - 1][y], y] for y in states)
    [prob, state] = finals.reduce maxFirst
    [prob, path[state]]
# Export the class when loaded as a CommonJS module.
if typeof module is "object"
  module.exports = BaumWelch
'use strict'
should = require 'should'
HMM = require '../libs/hmm'
BaumWelch = require '../libs/baum_welch'
describe 'Baum Welch', ->
  # Model trained by the first case; the second case starts EM from it.
  hmm = null

  # Split a space-separated line of "char/TAG" tokens into parallel arrays of
  # observations (chars) and hidden states (tags).
  splitTagged = (line) ->
    chars = []
    tags = []
    for token in line.split ' '
      [ch, tag] = token.split '/'
      chars.push ch
      tags.push tag
    [chars, tags]

  it 'should train a model', (done) ->
    seqs = [
      '中/B,ns 國/E,ns 人/B,n 民/E,n 將/S,d 滿/B,l 懷/M,l 信/M,l 心/E,l 地/S,u 開/B,v 創/E,v 新/S,a 的/S,u 業/B,n 績/E,n 。/S,w'
    ]
    # One observation list and one state list per training line. Buckets are
    # created per line, so adding more lines to `seqs` works.
    obsers = []
    states = []
    for line in seqs
      [chars, tags] = splitTagged line
      obsers.push chars
      states.push tags
    hmm = HMM.learn obsers, states
    done()

  it 'should train with existing model', (done) ->
    # This case needs the model produced by the previous case.
    should.exist hmm
    seqs = '中/B,nt 共/M,nt 中/M,nt 央/E,nt 總/B,n 書/M,n 記/E,n 、/S,w 國/B,n 家/E,n 主/B,n 席/E,n 江/B,nr 澤/M,nr 民/E,nr'
    [obsers, states] = splitTagged seqs
    bw = new BaumWelch obsers, states, hmm.states, hmm.charStates, hmm.pi, hmm.transProbs, hmm.emitProbs
    bw.doEM()
    testSeq = '中共中央總書記、國家主席江澤民'
    newHmm = new HMM bw.states, bw.charStates, bw.probStart, bw.probTrans, bw.probEmit
    [prob, predicts] = newHmm.viterbi testSeq.split('')
    done()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment