-
-
Save trappedinspacetime/e36fd5b6f62621594c9e01297c00ee22 to your computer and use it in GitHub Desktop.
Ngrams, LMs, and Perplexity in AWK and sed
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
#
# Build a maximum-likelihood n-gram language-model file from an n-gram count file.
#
# Usage: <script> NGRAM_COUNT_FILE LM_FILE
#
#   NGRAM_COUNT_FILE  lines of the form "<count><TAB><w1>[ <w2>[ <w3>]]". Unigram
#                     lines must come before bigram lines, and bigram lines before
#                     trigram lines, because each probability is the n-gram count
#                     divided by the already-stored count of its history.
#   LM_FILE           output (overwritten): a "\data\" header with type/token totals
#                     per order, then "\1-grams:", "\2-grams:" and "\3-grams:"
#                     sections whose entries read
#                     "<count> <prob> <log10 prob> <words...>", then "\end\".
if [[ $# -ne 2 ]]; then
  printf 'Usage: %s NGRAM_COUNT_FILE LM_FILE\n' "${0##*/}" >&2
  exit 2
fi

ngram_count_file="$1"
lm_file="$2"

if [[ ! -r "$ngram_count_file" ]]; then
  printf '%s: cannot read %s\n' "${0##*/}" "$ngram_count_file" >&2
  exit 1
fi

# The count file is read twice: once inside BEGIN (through getline) to compute the
# header totals, and once as standard input to emit the entries.
# Note the space before the line-continuation backslash: without it the awk
# program text would be glued onto the value of -v f=... .
awk -v f="$ngram_count_file" \
'function log10(x) { return log(x) / log(10.0) }

BEGIN {
  # Pass 1: per order, count distinct n-grams (type) and sum their counts (token).
  # The "> 0" test stops on end of file (0) and on a read error (-1); testing the
  # bare getline result would loop forever on an unreadable file.
  while ((getline < f) > 0) {
    if (/^[0-9]+\t[^ ]+$/)             { type[1]++; token[1] += $1 }
    if (/^[0-9]+\t[^ ]+ [^ ]+$/)       { type[2]++; token[2] += $1 }
    if (/^[0-9]+\t[^ ]+ [^ ]+ [^ ]+$/) { type[3]++; token[3] += $1 }
  }
  close(f)

  print "\\data\\"
  print "ngram 1: type=" type[1], "token=" token[1]
  print "ngram 2: type=" type[2], "token=" token[2]
  print "ngram 3: type=" type[3], "token=" token[3]
}

# Pass 2 (standard input). inblock tracks which section header was printed last,
# so each header is emitted once, at the first entry of its order.

# Unigram: P(w1) = c(w1) / total unigram tokens.
/^[0-9]+\t[^ ]+$/ {
  if (inblock == 0) { print "\n\\1-grams:"; inblock = 1 }
  w1 = $2; c[w1] = $1
  p = c[w1] / token[1]
  printf "%d %.15g %.15g %s\n", c[w1], p, log10(p), w1
}

# Bigram: P(w2 | w1) = c(w1 w2) / c(w1).
/^[0-9]+\t[^ ]+ [^ ]+$/ {
  if (inblock == 1) { print "\n\\2-grams:"; inblock = 2 }
  w1 = $2; w2 = $3; c[w1, w2] = $1
  p = c[w1, w2] / c[w1]
  printf "%d %.15g %.15g %s %s\n", c[w1, w2], p, log10(p), w1, w2
}

# Trigram: P(w3 | w1 w2) = c(w1 w2 w3) / c(w1 w2).
/^[0-9]+\t[^ ]+ [^ ]+ [^ ]+$/ {
  if (inblock == 2) { print "\n\\3-grams:"; inblock = 3 }
  w1 = $2; w2 = $3; w3 = $4; c[w1, w2, w3] = $1
  p = c[w1, w2, w3] / c[w1, w2]
  printf "%d %.15g %.15g %s %s %s\n", c[w1, w2, w3], p, log10(p), w1, w2, w3
}

END {
  print "\n\\end\\"
}' < "$ngram_count_file" > "$lm_file"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
#
# Extract 1-, 2- and 3-gram counts from a training corpus.
#
# Usage: <script> TRAINING_DATA NGRAM_COUNTS
#   TRAINING_DATA  input text file; it is read once per n-gram order
#   NGRAM_COUNTS   output file (overwritten) receiving "<count><TAB><n-gram>" lines
if [[ $# -ne 2 ]]; then
  printf 'Usage: %s TRAINING_DATA NGRAM_COUNTS\n' "${0##*/}" >&2
  exit 2
fi

trainingdata="$1"
ngramcounts="$2"
#######################################
# Emit every n-gram of each input line, one n-gram per output line.
# Each line is first wrapped as "<s> LINE </s>", so the sentence-boundary
# markers take part in the n-grams. Lines with too few tokens for the
# requested order produce no output.
# Arguments: $1 - n-gram order (positive integer: 1 = unigrams, 2 = bigrams, ...)
# Inputs:    text on stdin, tokens separated by whitespace
# Outputs:   n-grams on stdout, tokens joined by single spaces
# Returns:   2 if $1 is not a positive integer, otherwise the pipeline status
# Requires:  a sed supporting -r, \s, \S and \n inside brackets (GNU sed)
#######################################
ngram () {
  if [[ ! "${1:-}" =~ ^[1-9][0-9]*$ ]]; then
    printf 'ngram: order must be a positive integer, got "%s"\n' "${1:-}" >&2
    return 2
  fi
  # Number of tokens that follow the first token of each n-gram.
  local n=$(( $1 - 1 ))

  # "$1 = $1" makes awk rebuild the record with single spaces between fields.
  # The sed pattern below expects exactly one whitespace character between
  # tokens; without this normalisation, runs of blanks or leading blanks make
  # it glue neighbouring tokens together, and tabs leak into the output.
  awk '{ $1 = $1; print "<s> " $0 " </s>" }' \
    | sed -rn "
      :start
      /\s*(\S+)(\s\S+){$n}/ {
        # Pattern space becomes: <first n-gram> newline <line minus its first token>
        s/\s*(\S+)((\s\S+){$n})(.*)$/\1\2\\n\2\4/
        # Save it, print only the part before the newline, then restore.
        h
        s/\\n.*//
        p
        g
        # Drop the printed n-gram and loop on the remainder of the line.
        s/^[^\\n]*\\n\s*//
        t start
      }"
}
#######################################
# Count how many times each distinct line occurs on stdin.
# Inputs:  lines on stdin
# Outputs: one "<frequency><TAB><line>" record per distinct line on stdout;
#          the record order is whatever awk's "for (key in array)" yields,
#          so callers that need an order must sort the result
#######################################
count () {
  awk '
    { freq[$0]++ }
    END {
      for (g in freq) {
        print freq[g] "\t" g
      }
    }'
}
# Write the unigram, bigram and trigram counts to the output file, in that order.
# Within each order the records are sorted by descending count. The output is
# opened (and truncated) once for the whole loop, which matches truncating for
# the first order and appending for the following ones.
for order in 1 2 3; do
  ngram "$order" < "$trainingdata" | count | sort -nr
done > "$ngramcounts"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
#
# Score test sentences with a linearly interpolated trigram language model and
# report per-word log10 probabilities, per-sentence perplexity and corpus totals.
#
# Usage: <script> LM_FILE L1 L2 L3 TEST_DATA OUTPUT_FILE
#   LM_FILE      model file; entry lines read
#                "<count> <prob> <log10 prob> <w1> [<w2> [<w3>]]"
#   L1 L2 L3     interpolation weights applied to the unigram, bigram and
#                trigram probabilities respectively
#   TEST_DATA    one sentence per line, tokens separated by whitespace
#   OUTPUT_FILE  report destination (overwritten)
if [[ $# -ne 6 ]]; then
  printf 'Usage: %s LM_FILE L1 L2 L3 TEST_DATA OUTPUT_FILE\n' "${0##*/}" >&2
  exit 2
fi

lm_file="$1"
L1="$2"
L2="$3"
L3="$4"
test_data="$5"
output_file="$6"

for input in "$lm_file" "$test_data"; do
  if [[ ! -r "$input" ]]; then
    printf '%s: cannot read %s\n' "${0##*/}" "$input" >&2
    exit 1
  fi
done

# Note the space before the line-continuation backslash: without it the awk
# program text would be glued onto the value of -v L3=... .
awk -v lmfile="$lm_file" -v L1="$L1" -v L2="$L2" -v L3="$L3" \
'function log10(x) { return log(x) / log(10) }

# Print one report line for predicting w3 after the history (w1, w2) and return
# its log10 probability, or the OOV sentinel when w3 has no unigram entry.
#   n   position number shown in the report
#   w1  "" for the first prediction of a sentence (history is w2 only)
#   lm  model probabilities keyed by space-joined n-gram
# Names after the wide gap in the parameter list are local variables.
function printword(n, w1, w2, w3, lm,    condition, post, bigram, trigram, p1, p2, p3, lp) {
  condition = (w1 == "") ? w2 : (w1 " " w2)
  if (!(w3 in lm)) {
    printf "%d: lg P(%s | %s) = -inf (unknown word)\n", n, w3, condition
    return OOV
  }
  post = ""
  p1 = lm[w3]
  bigram = w2 " " w3
  trigram = w1 " " w2 " " w3
  # Membership tests avoid creating empty model entries as a side effect of lookup.
  if (bigram in lm) {
    p2 = lm[bigram]
  } else {
    p2 = 0; post = " (unseen ngrams)"
  }
  if (trigram in lm) {
    p3 = lm[trigram]
  } else {
    p3 = 0
    # With no w1 there is no trigram to look up, so nothing is flagged as unseen.
    if (w1 != "") { post = " (unseen ngrams)" }
  }
  lp = log10((L3 * p3) + (L2 * p2) + (L1 * p1))
  printf "%d: lg P(%s | %s) = %.15g%s\n", n, w3, condition, lp, post
  return lp
}

BEGIN {
  # Explicit out-of-vocabulary sentinel. A bare -inf in awk is an unset variable
  # negated, i.e. 0, which cannot be told apart from a real log probability of 0.
  OOV = "OOV"

  # Load probabilities (field 2) keyed by the space-joined n-gram. The "> 0" test
  # stops on end of file (0) and on a read error (-1).
  while ((getline < lmfile) > 0) {
    if (/^[0-9]+ [^ ]+ [^ ]+ [^ ]+$/)
      { lm[$4] = $2 }
    if (/^[0-9]+ [^ ]+ [^ ]+ [^ ]+ [^ ]+$/)
      { lm[$4 " " $5] = $2 }
    if (/^[0-9]+ [^ ]+ [^ ]+ [^ ]+ [^ ]+ [^ ]+$/)
      { lm[$4 " " $5 " " $6] = $2 }
  }
  close(lmfile)
}

# Every line holding at least one token is a sentence. NF > 0 replaces the
# pattern /[^\s]+/, whose bracket expression is not a portable spelling of
# "non-whitespace" and could skip or accept the wrong lines.
NF > 0 {
  sent_cnt++
  s = "<s> " $0 " </s>"
  sum = 0; cnt = 0; oov = 0
  printf "\nSent #%d: %s\n", sent_cnt, s
  split(s, tokens)

  # The first prediction has only <s> as history; the rest use two tokens.
  # tokens[] holds NF + 2 entries, so the loop ends by predicting </s>.
  p = printword(1, "", tokens[1], tokens[2], lm)
  if (p == OOV) { oov++ } else { sum += p; cnt++ }
  for (i = 3; i < NF + 3; i++) {
    p = printword(i - 1, tokens[i - 2], tokens[i - 1], tokens[i], lm)
    if (p == OOV) { oov++ } else { sum += p; cnt++ }
  }

  print "1 sentence,", NF, "words,", oov, "OOVs"
  if (cnt > 0) {
    # "^" is the POSIX awk exponent operator; "**" is an extension.
    ppl = 10 ^ ((sum * -1) / cnt)
    printf "lgprob=%.15g ppl=%.15g\n\n\n", sum, ppl
  } else {
    # Every prediction was out of vocabulary: perplexity is undefined.
    printf "lgprob=%.15g ppl=undefined\n\n\n", sum
  }
  total_sum = total_sum + sum
  total_cnt = total_cnt + cnt
  total_oov = total_oov + oov
}

END {
  # Each sentence adds one </s> prediction that is not a word, hence - sent_cnt.
  words = total_cnt + total_oov - sent_cnt
  print "\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%"
  print "sent_num=" (sent_cnt + 0) " word_num=" words " oov_num=" (total_oov + 0)
  if (total_cnt > 0) {
    avglp = total_sum / total_cnt
    ppl = 10 ^ (avglp * -1)
    printf "lgprob=%.15g ave_lgprob=%.15g ppl=%.15g", total_sum, avglp, ppl
  } else {
    # Nothing was scored (no sentences, or only out-of-vocabulary words).
    printf "lgprob=%.15g ave_lgprob=undefined ppl=undefined", total_sum
  }
}' < "$test_data" > "$output_file"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment