compute-best-mix.py : Python port of the venerated SRILM tool
#!/usr/bin/python
import re, math

def LoadPPLFile (pplfile) :
    """
    Load up the salient info from a -debug 2 PPL file
    generated by the SRILM ngram tool.
    """
    ppl_info = []
    for line in open (pplfile, "r") :
        line = line.strip ()
        if line.startswith ("p(") :
            tok, prob = line.split ("=")
            probs = re.split (r"\s+", re.sub (r"[\[\]]", "", prob).strip ())
            tok = re.sub (r"^p\( ", "", tok)
            tok = re.sub (r" \|.*$", "", tok)
            ppl_info.append ([tok, int (re.sub (r"gram", "", probs[0])), float (probs[2])])
    # Debug output: entry count and a sample parsed entry.
    print len (ppl_info)
    print "\t", ppl_info[2]
    return ppl_info
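# For reference, a -debug 2 line from ngram looks something like the
# following (the exact spacing and the numbers are illustrative, not
# taken from the gist):
#   p( the | <s> ) = [2gram] 0.0786 [ -1.10458 ]
# which LoadPPLFile parses into ["the", 2, -1.10458]: the token, the
# ngram order that fired, and the log10 probability.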
class MixtureComputer () :
    """
    Python port of the SRILM gawk tool 'compute-best-mix'.
    Should produce the same result as:
      $ compute-best-mix ppl.file1 ppl.file2 ... ppl.fileN
    """

    def __init__ (self, ppl_infos, lambdas=[], precision=0.001, verbose=False) :
        self.M_LN10 = 2.30258509299404568402
        self.logINF = -320
        self.precision = precision
        self.ppl_infos = ppl_infos
        self.lambdas, self.priors = self._init_lambdas (lambdas)
        self.max_iter = 100
        self.post_totals = []
        self.nc = len (self.ppl_infos)  #Number of components
        self.log10priors = []

    def _init_lambdas (self, lambdas) :
        # Default to uniform weights; otherwise normalize the supplied
        # lambdas so the initial priors sum to 1.
        if len (lambdas) == 0 :
            lambdas = [1. / len (self.ppl_infos) for l in xrange (len (self.ppl_infos))]
        lambda_sum = sum (lambdas)
        priors_ = [l / lambda_sum for l in lambdas]
        return lambdas, priors_
    def _sum_word (self, i) :
        log_posts_ = [self.log10priors[j] + self.ppl_infos[j][i][2]
                      for j in xrange (self.nc)]
        log_sum_ = log_posts_[0]
        for log_post_ in log_posts_[1:] :
            log_sum_ = math.log (
                (math.pow (10, log_sum_) + math.pow (10, log_post_)),
                10)
        return log_sum_, log_posts_
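    # For example, _sum_word combines two component log10 posteriors of
    # -1.0 and -2.0 as:
    #   log10 (10**-1.0 + 10**-2.0) = log10 (0.11) ~= -0.9586
    # i.e. a log-domain sum dominated by, but slightly above, the larger term.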
    def OptimizeLambdas (self) :
        """
        So how does this actually work?  There is no explanation beyond
        the source of the original gawk script.  It is basically a
        simple, iterative EM-like estimation procedure:
          1. Load all the PPL results from the component models.
          2. Initialize the mixture weights (uniform by default).
          3. For each word in the test set, compute the lambda-scaled
             log posterior for each component model.  For example, with
               * word    = WORD1
               * models  = M1, M2, M3
               * lambdas = L1, L2, L3
             the log posterior for model Mj is
               log10(Lj) + log10 P_Mj(WORD1),
             then compute the log sum of these posteriors for the word.
          4. Accumulate the per-model posterior totals: each per-model
             posterior from (3.) divided by the per-word sum from (3.),
             i.e. a subtraction in the log domain.
          5. Recompute the lambda priors, normalizing by the total
             number of (non-OOV) words in the test set.
          6. Finally, compute the absolute change between the previous
             prior values and the newly recomputed ones.  If the change
             for any model exceeds the precision threshold, and we have
             not reached the max number of iterations, return to (3.).
        The algorithm terminates when either max-iters is reached or
        the change for every model dips below the threshold.  See the
        standalone sketch after this listing for the same update on
        toy data.
        """
        have_converged = False
        iteration = 0
        while not have_converged :
            iteration += 1
            log_like = 0.0
            post_totals = [0.0 for p in self.ppl_infos]
            self.log10priors = [math.log (self.priors[i], 10)
                                for i in xrange (self.nc)]
            for i in xrange (len (self.ppl_infos[0])) :
                # Compute the sum for this word, across all components
                log_sum, log_posts = self._sum_word (i)
                log_like += log_sum
                for j in xrange (len (self.ppl_infos)) :
                    post_totals[j] += math.pow (10, log_posts[j] - log_sum)
            # Report the current weights and mixture perplexity.
            print iteration, \
                " ".join ([str (x) for x in self.priors]), \
                math.pow (10, -log_like / len (self.ppl_infos[0]))
            have_converged = True
            for j in xrange (len (self.ppl_infos)) :
                last_prior = self.priors[j]
                self.priors[j] = post_totals[j] / len (self.ppl_infos[0])
                abs_change = abs (last_prior - self.priors[j])
                if abs_change > self.precision :
                    have_converged = False
            if iteration > self.max_iter :
                have_converged = True
        return
if __name__ == "__main__" :
    import sys, argparse

    example = "USAGE: {0} --ppl ppl.1.txt,ppl.2.txt,ppl.3.txt".format (sys.argv[0])
    parser = argparse.ArgumentParser (description = example)
    parser.add_argument ("--ppl", "-p", help="List of ppl files from 'ngram'.", required=True)
    parser.add_argument ("--verbose", "-v", help="Verbose mode.", default=False, action="store_true")
    args = parser.parse_args ()

    pplfiles = args.ppl.split (",")
    pplinfos = []
    for f in pplfiles :
        pplinfos.append (LoadPPLFile (f))

    mixer = MixtureComputer (pplinfos)
    mixer.OptimizeLambdas ()
    print mixer.priors
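For reference, the update that OptimizeLambdas implements is the standard EM re-estimation of mixture weights, not gradient descent. Below is a minimal standalone sketch of the same E-step/M-step loop; the two-model setup and the per-word log10 probabilities are made up for illustration:

import math

# Made-up per-word log10 probabilities from two component models.
model_a = [-1.2, -0.8, -2.1, -1.5]
model_b = [-0.9, -1.5, -1.7, -0.6]
priors = [0.5, 0.5]

for iteration in xrange (100) :
    post_totals = [0.0, 0.0]
    for pa, pb in zip (model_a, model_b) :
        # E-step: each component's posterior responsibility for this word.
        log_posts = [math.log (priors[0], 10) + pa,
                     math.log (priors[1], 10) + pb]
        log_sum = math.log (math.pow (10, log_posts[0])
                            + math.pow (10, log_posts[1]), 10)
        for j in xrange (2) :
            post_totals[j] += math.pow (10, log_posts[j] - log_sum)
    # M-step: new weight = average responsibility over all words.
    new_priors = [t / len (model_a) for t in post_totals]
    converged = max (abs (n - o) for n, o in zip (new_priors, priors)) < 0.001
    priors = new_priors
    if converged :
        break

print priors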
You need to change line 18 to something like:
ppl_info.append ([tok, int (re.sub (r"gram", "", probs[0].replace ("OOV", "0"))), float (probs[2])])
Also, it is good to deal with -inf if you have OOVs in your text.
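One way to do both at once is to skip OOV/zeroprob entries entirely rather than coercing them, which is also consistent with step 5 of the docstring (only non-OOV words are counted). A sketch, assuming SRILM marks such lines with an 'OOV' order field and a '-inf' log probability; LoadPPLFileSafe is a hypothetical variant, not part of the gist:

import re

def LoadPPLFileSafe (pplfile) :
    """Like LoadPPLFile, but drop OOV/zeroprob entries (hypothetical variant)."""
    ppl_info = []
    for line in open (pplfile, "r") :
        line = line.strip ()
        if not line.startswith ("p(") :
            continue
        tok, prob = line.split ("=")
        probs = re.split (r"\s+", re.sub (r"[\[\]]", "", prob).strip ())
        # Assumed markers: OOVs carry an 'OOV' order field and a log10
        # probability of '-inf'; skipping them keeps int() from crashing
        # and keeps -inf out of the downstream EM sums.
        if probs[0] == "OOV" or probs[2] == "-inf" :
            continue
        tok = re.sub (r"^p\( ", "", tok)
        tok = re.sub (r" \|.*$", "", tok)
        ppl_info.append ([tok, int (re.sub (r"gram", "", probs[0])), float (probs[2])])
    return ppl_info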
Hello, do you know which weight-update method the compute-best-mix script uses? Gradient descent?
Hello. Looks cool.
But right now I'm getting:
$ compute-best-mix.py --ppl ppl1,ppl2
Traceback (most recent call last):
  File "compute-best-mix.py", line 137, in <module>
    pplinfos.append (LoadPPLFile (f))
  File "compute-best-mix.py", line 18, in LoadPPLFile
    ppl_info.append ([tok, int(re.sub(r"gram", "",probs[0])), float(probs[2])])
ValueError: invalid literal for int() with base 10: 'OOV'