Last active
August 29, 2015 14:01
-
-
Save indiejoseph/f7dc6982f57ef7508db6 to your computer and use it in GitHub Desktop.
BaumWelch Algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| 'use strict' | |
| _ = require 'lodash' | |
| # Log-domain stand-in for an impossible event (log 0): `log` returns it for non-positive input, and probability lookups use it as the default for missing entries. | |
| MIN_FLOAT = -3.14e100 | |
# Adds `obj.default(prop, value)` to every object: assigns `value` to `prop`
# only when the object has no own property of that name.
# Returns the assigned value, or undefined when the property already existed.
#
# The method is installed as a NON-enumerable property so that it does not
# leak into `for key of obj` loops over ordinary objects, which a plain
# `Object::default = ...` assignment would cause.
Object.defineProperty Object.prototype, 'default',
  enumerable: false
  configurable: true
  writable: true
  value: (prop, value) ->
    # Call hasOwnProperty through the prototype so a data key that happens to
    # be named "hasOwnProperty" cannot shadow the check.
    @[prop] = value unless Object::hasOwnProperty.call this, prop
# Adds `obj.getValue(prop, fallback)` to every object: returns the object's
# own property `prop` when present, otherwise `fallback`.
#
# The method is installed as a NON-enumerable property so that it does not
# leak into `for key of obj` loops over ordinary objects, which a plain
# `Object::getValue = ...` assignment would cause.
Object.defineProperty Object.prototype, 'getValue',
  enumerable: false
  configurable: true
  writable: true
  value: (prop, value) ->
    # Call hasOwnProperty through the prototype so a data key that happens to
    # be named "hasOwnProperty" cannot shadow the check.
    if Object::hasOwnProperty.call this, prop
      @[prop]
    else
      value
# Adds up the elements of `arr` with lodash's reduce, starting from 0
# (so an empty collection sums to 0).
Math.sum = (arr) ->
  addTo = (running, item) -> running + item
  _.reduce arr, addTo, 0
# Natural logarithm that never returns -Infinity or NaN: any input that is
# not strictly positive (including NaN) maps to MIN_FLOAT.
log = (value) ->
  return MIN_FLOAT unless value > 0
  Math.log value
# Baum-Welch (EM) re-estimation of a hidden Markov model from one labelled
# sequence.
#
# All probabilities handled by this class are natural-log probabilities:
# they are combined by addition, and MIN_FLOAT stands for "impossible".
# Probability tables are plain objects keyed by state name (and, for
# emissions, by observation); missing entries are read as MIN_FLOAT.
class BaumWelch

  # ---- private helpers (closure-scoped, not part of the public interface) ----

  # Own-property lookup in a log-probability table, MIN_FLOAT when absent.
  logProb = (table, key) ->
    if table? and Object::hasOwnProperty.call(table, key) then table[key] else MIN_FLOAT

  # log(sum(exp(v))) computed relative to the largest term, so long sequences
  # do not underflow to zero. Returns MIN_FLOAT when every term is impossible,
  # the list is empty, or the result is not a usable number.
  logSumExp = (logValues) ->
    peak = -Infinity
    for v in logValues
      peak = v if v > peak
    return MIN_FLOAT unless peak > MIN_FLOAT
    total = 0
    total += Math.exp(v - peak) for v in logValues
    result = peak + Math.log total
    if result > MIN_FLOAT then result else MIN_FLOAT

  # Normalises a log score by a log total; impossible stays impossible.
  normalise = (score, norm) ->
    if score > MIN_FLOAT and norm > MIN_FLOAT then score - norm else MIN_FLOAT

  # True when a trellis has not been built yet.
  isEmptyNet = (net) ->
    not net? or Object.keys(net).length is 0

  # Number of time steps in a per-step collection (array, or object keyed by index).
  stepCount = (steps) ->
    if steps.length? then steps.length else Object.keys(steps).length

  # Layer `t` of a prebuilt trellis, but only when the caller is asking about
  # this model's own sequence and state list (checked by identity); else null.
  cachedLayer = (model, net, seq, t, states) ->
    return null unless seq is model.seq and states is model.states
    layer = net?[t]
    if layer? then layer else null

  # Parameters: output (observation) sequence, hidden state sequence, state
  # list, observation -> states map, initial state distribution, initial
  # transition probabilities, initial emission probabilities, and the maximum
  # number of EM iterations.
  # NOTE: @probTrans, @probEmit and @charStates are the caller's objects and
  # are updated in place.
  constructor: (outSeq, hiddenSeq, @states, @charStates={}, @probStart, @probTrans, @probEmit, @maxIterNum=20) ->
    # @seq is a list of [observation, state] pairs.
    @seq = _.zip outSeq, hiddenSeq
    @forwardNet = {}
    @backwardNet = {}
    for s in @states
      @probTrans[s] = {} unless @probTrans[s]
      @probEmit[s] = {} unless @probEmit[s]
    # Record which states each observation was seen with.
    for c, i in outSeq
      s = hiddenSeq[i]
      @charStates[c] = [] unless Object::hasOwnProperty.call @charStates, c
      @charStates[c].push s unless s in @charStates[c]

  # Builds the forward and backward trellises for @seq with the current
  # parameters. A trellis that is already populated is left untouched, so
  # callers that changed the parameters must reset @forwardNet/@backwardNet
  # to {} first (doEM does this on every iteration).
  buildnet: ->
    if isEmptyNet @forwardNet
      @forwardNet = {}
      for sw, t in @seq
        layer = {}
        for s in @states
          emitP = logProb @probEmit[s], sw[0]
          if t is 0
            layer[s] = logProb(@probStart, s) + emitP
          else
            terms = (@forwardNet[t - 1][s0] + logProb(@probTrans[s0], s) + emitP for s0 in @states)
            layer[s] = logSumExp terms
        @forwardNet[t] = layer
    if isEmptyNet @backwardNet
      @backwardNet = {}
      last = @seq.length - 1
      # `by -1` keeps the loop empty for an empty sequence.
      for i in [last..0] by -1
        layer = {}
        for s in @states
          if i is last
            # Backward initial value at the final time step: log(1) = 0.
            layer[s] = 0
          else
            nextObs = @seq[i + 1][0]
            terms = (@backwardNet[i + 1][s1] + logProb(@probTrans[s], s1) + logProb(@probEmit[s1], nextObs) for s1 in @states)
            layer[s] = logSumExp terms
        @backwardNet[i] = layer
    return

  # Forward log-probabilities at the LAST position of `seq`
  # (a list of [observation, state] pairs). Returns a map state -> value;
  # an empty `seq` yields {}.
  forward: (seq, states) ->
    layer = {}
    for sw, i in seq
      earlier = layer
      layer = {}
      for s in states
        emitP = logProb @probEmit[s], sw[0]
        if i is 0
          layer[s] = logProb(@probStart, s) + emitP
        else
          terms = (earlier[s0] + logProb(@probTrans[s0], s) + emitP for s0 in states)
          layer[s] = logSumExp terms
    layer

  # Backward log-probabilities at the FIRST position of `seq`, i.e. pass the
  # suffix starting at the time step of interest. Returns a map
  # state -> value; an empty `seq` yields {}.
  backward: (seq, states) ->
    layer = {}
    last = seq.length - 1
    for i in [last..0] by -1
      later = layer
      layer = {}
      for s in states
        if i is last
          # Backward initial value at the final time step: log(1) = 0.
          layer[s] = 0
        else
          nextObs = seq[i + 1][0]
          terms = (later[s1] + logProb(@probTrans[s], s1) + logProb(@probEmit[s1], nextObs) for s1 in states)
          layer[s] = logSumExp terms
    layer

  # Gamma at time t: the normalised product of the forward and backward
  # probabilities, returned as a map state -> log-probability.
  getr: (seq, t, states) ->
    probForward = cachedLayer(this, @forwardNet, seq, t, states) ? @forward(seq.slice(0, t + 1), states)
    probBackward = cachedLayer(this, @backwardNet, seq, t, states) ? @backward(seq.slice(t), states)
    joint = {}
    joint[s] = probForward[s] + probBackward[s] for s in states
    norm = logSumExp (joint[s] for s in states)
    gamma = {}
    gamma[s] = normalise joint[s], norm for s in states
    gamma

  # Xi at time t: normalised log-probability of being in state s at t and in
  # state s1 at t+1. Returns a map keyed by `[s, s1]` (which JavaScript
  # stringifies as "s,s1"), or null when t is the last time step.
  # NOTE(review): the "s,s1" key is ambiguous if state names that contain
  # commas can be split in more than one way; the key format is kept for
  # compatibility with existing callers.
  getxi: (seq, t, states) ->
    return null if t >= seq.length - 1
    probForward = cachedLayer(this, @forwardNet, seq, t, states) ? @forward(seq.slice(0, t + 1), states)
    # Xi needs the backward value of the NEXT time step.
    probBackward = cachedLayer(this, @backwardNet, seq, t + 1, states) ? @backward(seq.slice(t + 1), states)
    nextObs = seq[t + 1][0]
    xi = {}
    scores = []
    for s in states
      for s1 in states
        score = probForward[s] + probBackward[s1] + logProb(@probTrans[s], s1) + logProb(@probEmit[s1], nextObs)
        xi[[s, s1]] = score
        scores.push score
    norm = logSumExp scores
    for s in states
      for s1 in states
        xi[[s, s1]] = normalise xi[[s, s1]], norm
    xi

  # Runs EM for at most @maxIterNum iterations, stopping early once Viterbi
  # decoding reproduces the labelled states. The `self` parameter is unused
  # and kept only for signature compatibility. @maxIterNum itself is not
  # consumed, so doEM can be called again.
  doEM: (self) ->
    return if @seq.length is 0
    remaining = @maxIterNum
    while remaining > 0
      # E step: the trellises depend on the current parameters, so rebuild them.
      @forwardNet = {}
      @backwardNet = {}
      @buildnet()
      seqR = []
      seqXi = []
      for i in [0...@seq.length] by 1
        seqR.push @getr(@seq, i, @states)
        seqXi.push @getxi(@seq, i, @states)
      # M step
      @updatePI seqR, @states
      @updateTrans seqR, seqXi, @states
      @updateEmit @seq, seqR, @states
      # The parameters just changed; drop the now-stale trellises.
      @forwardNet = {}
      @backwardNet = {}
      break if @ErrorIsOk @seq
      remaining--
    return

  # Updates the initial state probabilities from gamma at time 0.
  updatePI: (seqR, states) ->
    @probStart[s] = seqR[0][s] for s in states
    return

  # Updates the state transition probabilities:
  # (sum of xi over t < T-1) / (sum of gamma over t < T-1), in the log domain.
  updateTrans: (seqR, seqXi, states) ->
    steps = stepCount(seqR) - 1
    for s in states
      @probTrans[s] = {} unless @probTrans[s]
      occupancy = logSumExp (seqR[t][s] for t in [0...steps] by 1)
      for s1 in states
        expected = logSumExp (seqXi[t][[s, s1]] for t in [0...steps] by 1)
        @probTrans[s][s1] = normalise expected, occupancy
    return

  # Updates the emission probabilities of every observation that occurs in
  # `seq`: (gamma mass at the steps showing that observation) / (total gamma
  # mass of the state). Entries for other observations are left as they are.
  updateEmit: (seq, seqR, states) ->
    steps = stepCount seqR
    for s in states
      @probEmit[s] = {} unless @probEmit[s]
      # Prototype-free map: observations are arbitrary strings.
      byObservation = Object.create null
      allSteps = []
      for t in [0...steps] by 1
        o = seq[t][0]
        byObservation[o] = [] unless byObservation[o]?
        byObservation[o].push seqR[t][s]
        allSteps.push seqR[t][s]
      occupancy = logSumExp allSteps
      for o, logMasses of byObservation
        @probEmit[s][o] = normalise logSumExp(logMasses), occupancy
    return

  # Convergence check: true when Viterbi decoding of the observations in
  # `seq` yields exactly the labelled state at every position.
  ErrorIsOk: (seq) ->
    [prob, path] = @viterbi((sw[0] for sw in seq), @states, @probStart, @probTrans, @probEmit)
    return no unless path? and path.length is seq.length
    for sw, i in seq
      return no if sw[1] isnt path[i]
    yes

  # Viterbi decoding. Returns [logProbability, statePath] for the best path;
  # ties are resolved in favour of the state listed first in `states`.
  # With no observations or no states there is nothing to decode and the
  # result is [MIN_FLOAT, []].
  viterbi: (obs, states, startP, transP, emitP) ->
    return [MIN_FLOAT, []] if obs.length is 0 or states.length is 0
    V = [{}]
    path = {}
    for y in states
      V[0][y] = logProb(startP, y) + logProb(emitP[y], obs[0])
      path[y] = [y]
    for t in [1...obs.length] by 1
      V.push {}
      newpath = {}
      for y in states
        emP = logProb emitP[y], obs[t]
        best = null
        bestPrev = states[0]
        for y0 in states
          score = V[t - 1][y0] + logProb(transP[y0], y) + emP
          if best is null or score > best
            best = score
            bestPrev = y0
        # Store the layer value: the next time step reads it.
        V[t][y] = best
        newpath[y] = path[bestPrev].concat y
      path = newpath
    finalLayer = V[obs.length - 1]
    best = null
    bestState = states[0]
    for y in states
      if best is null or finalLayer[y] > best
        best = finalLayer[y]
        bestState = y
    [best, path[bestState]]
# Export the class only when a CommonJS-style `module` object is present.
if typeof module is "object"
  module.exports = BaumWelch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| 'use strict' | |
| should = require 'should' | |
| HMM = require '../libs/hmm' | |
| BaumWelch = require '../libs/baum_welch' | |
describe 'Baum Welch', ->

  # Splits a line of space-separated "observation/state" tokens into
  # [observations, states].
  parseTagged = (line) ->
    obsers = []
    states = []
    for token in line.split ' '
      t = token.split '/'
      obsers.push t[0]
      states.push t[1]
    [obsers, states]

  # Trains a fresh base model so that each spec is independent of the others
  # (and of execution order).
  trainBaseModel = ->
    seqs = [
      '中/B,ns 國/E,ns 人/B,n 民/E,n 將/S,d 滿/B,l 懷/M,l 信/M,l 心/E,l 地/S,u 開/B,v 創/E,v 新/S,a 的/S,u 業/B,n 績/E,n 。/S,w'
    ]
    obsers = []
    states = []
    for s in seqs
      [o, st] = parseTagged s
      obsers.push o
      states.push st
    HMM.learn obsers, states

  it 'should train a model', ->
    hmm = trainBaseModel()
    should.exist hmm

  it 'should train with existing model', ->
    hmm = trainBaseModel()
    should.exist hmm
    seqs = '中/B,nt 共/M,nt 中/M,nt 央/E,nt 總/B,n 書/M,n 記/E,n 、/S,w 國/B,n 家/E,n 主/B,n 席/E,n 江/B,nr 澤/M,nr 民/E,nr'
    [obsers, states] = parseTagged seqs
    bw = new BaumWelch obsers, states, hmm.states, hmm.charStates, hmm.pi, hmm.transProbs, hmm.emitProbs
    bw.doEM()
    should.exist bw.probStart
    should.exist bw.probTrans
    should.exist bw.probEmit
    testSeq = '中共中央總書記、國家主席江澤民'
    newHmm = new HMM bw.states, bw.charStates, bw.probStart, bw.probTrans, bw.probEmit
    # NOTE(review): assumes HMM#viterbi returns [probability, path], as the
    # original destructuring implied — confirm against libs/hmm.
    [prob, predicts] = newHmm.viterbi testSeq.split('')
    should.exist predicts
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment