Created
July 27, 2012 15:44
-
-
Save AdolfVonKleist/3188749 to your computer and use it in GitHub Desktop.
Expectation and Maximization functions from M2MFstAligner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void M2MFstAligner::expectation( ){ | |
for( int i=0; i<fsas.size(); i++ ){ | |
//Comput Forward and Backward probabilities | |
ShortestDistance( fsas.at(i), &alpha ); | |
ShortestDistance( fsas.at(i), &beta, true ); | |
//Compute the normalized Gamma probabilities and | |
// update our running tally | |
for( StateIterator<VectorFst<LogArc> > siter(fsas.at(i)); !siter.Done(); siter.Next() ){ | |
LogArc::StateId q = siter.Value(); | |
for( ArcIterator<VectorFst<LogArc> > aiter(fsas.at(i),q); !aiter.Done(); aiter.Next() ){ | |
const LogArc& arc = aiter.Value(); | |
const LogWeight& gamma = Divide(Times(Times(alpha[q], arc.weight), beta[arc.nextstate]), beta[0]); | |
//Check for any BadValue results, otherwise add to the tally. | |
//We call this 'prev_alignment_model' which may seem misleading, but | |
// this conventions leads to 'alignment_model' being the final version. | |
if( gamma.Value()==gamma.Value() ){ | |
prev_alignment_model[arc.ilabel] = Plus(prev_alignment_model[arc.ilabel], gamma); | |
total = Plus(total, gamma); | |
} | |
} | |
} | |
alpha.clear(); | |
beta.clear(); | |
} | |
} | |
float M2MFstAligner::maximization( bool lastiter ){ | |
//Maximization. Simple count normalization. Probably get an improvement | |
// by using a more sophisticated regularization approach. | |
map<LogArc::Label,LogWeight>::iterator it; | |
float change = abs(total.Value()-prevTotal.Value()); | |
//cout << "Total: " << total << " Change: " << abs(total.Value()-prevTotal.Value()) << endl; | |
prevTotal = total; | |
//Normalize and iterate to the next model. We apply it dynamically | |
// during the expectation step. | |
for( it=prev_alignment_model.begin(); it != prev_alignment_model.end(); it++ ){ | |
alignment_model[(*it).first] = Divide((*it).second,total); | |
(*it).second = LogWeight::Zero(); | |
} | |
for( int i=0; i<fsas.size(); i++ ){ | |
for( StateIterator<VectorFst<LogArc> > siter(fsas[i]); !siter.Done(); siter.Next() ){ | |
LogArc::StateId q = siter.Value(); | |
for( MutableArcIterator<VectorFst<LogArc> > aiter(&fsas[i], q); !aiter.Done(); aiter.Next() ){ | |
LogArc arc = aiter.Value(); | |
arc.weight = alignment_model[arc.ilabel]; | |
aiter.SetValue(arc); | |
} | |
} | |
} | |
total = LogWeight::Zero(); | |
return change; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment