Created February 7, 2013
Variable-length n-gram Markov chain for sequence learning.
""" | |
Fabio N. Kepler, Sergio L. S. Mergen, Cleo Z. Billa, Jose C. Bins | |
2012 | |
First written for the PAutomaC Competition, 2012. | |
Check URL for file format details. | |
(http://ai.cs.umbc.edu/icgi2012/challenge/Pautomac) | |
You are free to use it. We would like if you let us know ([email protected]). | |
Version: 2 | |
""" | |
import math
import time
import collections
import argparse


def number(arg):
    # Thin wrapper around float(); keeps the numeric type used for
    # probabilities and counts swappable in a single place.
    return float(arg)
class VLNGramTree:
    def __init__(self, rootLabel='ROOT', max_depth=10, cut_value=0.0, density=1):
        self.root = self.Node(rootLabel)
        self.max_depth = max_depth
        self.cut_value = cut_value
        self.density = density
        self.size = 0
        self.depth = 1

    def train(self, data):
        # Build the full context tree, then discard rare nodes (shake) and
        # leaves that add little information over their parents (prune).
        self.build(data)
        self.shake()
        self.prune()
    def build(self, data):
        window = self.max_depth
        for sequence in data:
            # -1 and -2 mark the start and end of a sequence.
            ngramseq = [-1] + sequence + [-2]
            for end in range(1, len(ngramseq)):
                start = max(end - window, 0)
                # Contexts are stored reversed: most recent symbol first.
                self.addSequence(list(reversed(ngramseq[start:end])), ngramseq[end])
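    # Illustrative walk-through of build(): with max_depth=3 the sequence
    # [5, 2] becomes [-1, 5, 2, -2] and yields the (reversed context, next
    # symbol) pairs ([-1], 5), ([5, -1], 2) and ([2, 5, -1], -2), each of
    # which is inserted by addSequence() below.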
    def addSequence(self, sequence, dest):
        self.root.addSprout(dest)
        curNode = self.root
        num_new_nodes = 0
        for label in sequence:
            if label in curNode.children:
                curNode.children[label].addSprout(dest)
                curNode = curNode.children[label]
            else:
                newNode = self.Node(label, parent=curNode)
                newNode.addSprout(dest)
                curNode = newNode
                num_new_nodes += 1
        self.size += num_new_nodes
    def shake(self, protect=[-1, -2]):
        '''
        Discard nodes occurring fewer than @self.density times.
        '''
        queue = collections.deque([self.root])
        discarded = 0
        while len(queue) > 0:
            node = queue.popleft()
            if node.label in protect:
                # Do not discard sentence border markers, but still visit their subtrees.
                queue.extend(node.children.values() or [])
                continue
            if node.count < self.density:
                discarded += 1
                node.parent.children.pop(node.label)
                del node
            else:
                queue.extend(node.children.values() or [])
        self.size -= discarded
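    # Illustrative effect of shake(): with density=2, a context node seen
    # only once is unlinked together with its whole subtree, trading a little
    # fit to the training data for a much smaller tree.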
    def prune(self):
        '''
        Discard leaves that diverge less than @self.cut_value from their parents.
        '''
        num_cut = 0
        leaves = self.getLeaves()
        while leaves:
            leaf = leaves.popleft()
            if not leaf.parent: continue
            #if leaf.label == -1 or leaf.label == -2: continue # Do not prune sentence border markers.
            divergence = self.getDivergence(leaf, leaf.parent) # according to a paper by Pereira
            #divergence = self.getDivergence(leaf.parent, leaf) # according to other sources
            if divergence < self.cut_value: # or > depending on the direction of the divergence
                leaf.parent.children.pop(leaf.label)
                if not leaf.parent.children: # parent became a leaf; reconsider it
                    leaves.append(leaf.parent)
                del leaf
                num_cut += 1
            else: # leaf will stay
                self.depth = max(self.depth, leaf.depth)
        self.size -= num_cut
    def getDivergence(self, P, Q):
        '''
        Using Kullback-Leibler divergence (with absolute discounting smoothing).
        (http://www.cs.bgu.ac.il/~elhadad/nlp09/KL.html)
        '''
        DKL = number(0.0)
        SP = set(P.sprouts.keys())
        SQ = set(Q.sprouts.keys())
        SU = SP.union(SQ)
        epsilon = number(0.0000000001)
        # Mass discounted from each distribution to cover symbols it has not seen.
        pcount = number(epsilon * len(SU - SP)) / number(len(SP))
        qcount = number(epsilon * len(SU - SQ)) / number(len(SQ))
        for i in SU:
            pi = P.getSproutLogProbability(i)
            if pi is not None: pi = pi - pcount  # 'is not None': 0.0 is a valid log-probability
            else: pi = epsilon
            qi = Q.getSproutLogProbability(i)
            if qi is not None: qi = qi - qcount
            else: qi = epsilon
            DKL += number(math.exp(pi)) * number(pi - qi)
        # Weight by the probability of reaching P, so rare contexts are pruned more easily.
        P_prob = number(P.count) / number(self.root.count)
        DKL = P_prob * DKL
        return DKL
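    # Hypothetical numbers, just to illustrate the weighted KL above: if a
    # leaf P and its parent Q both have sprout distribution {0: 0.5, 1: 0.5},
    # sum_i p_i * (log p_i - log q_i) is 0 and the leaf is prunable; if P has
    # {0: 0.9, 1: 0.1} instead, the sum is
    #   0.9 * log(0.9/0.5) + 0.1 * log(0.1/0.5) ~= 0.368,
    # which is then scaled by P.count / root.count before being compared
    # against cut_value in prune().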
    #--------------------------------------------------------------------------
    def getSequenceLogProbability(self, sequence, dest):
        probs = collections.deque([self.root.getSproutLogProbability(dest)])
        curNode = self.root
        for label in sequence:
            if label in curNode.children:
                p = curNode.children[label].getSproutLogProbability(dest)
                if p is not None:  # 'is not None': 0.0 is a valid log-probability
                    probs.append(p)
                else:
                    break
                curNode = curNode.children[label]
            else:
                break
        # Use the estimate from the deepest matching context only.
        return number(probs[-1])
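    # Backoff behavior (illustrative): when asked for P(-2 | [2, 5, -1]) and
    # only the branch ROOT -> 2 exists in the tree, the loop stops there and
    # the estimate stored at the deepest matching context (node 2) is returned.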
    def predictProbabilities(self, testData):
        probs = collections.deque([])
        window = self.max_depth
        for sequence in testData:
            prob = number(0.0)
            ngramseq = [-1] + sequence + [-2]
            for end in range(1, len(ngramseq)):
                start = max(end - window, 0)
                p = self.getSequenceLogProbability(list(reversed(ngramseq[start:end])), ngramseq[end])
                prob = prob + p
            probs.append(number(math.exp(prob)))
        return probs
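    # Illustrative: the probability assigned to the sequence [5, 2] is
    # P(5 | -1) * P(2 | 5, -1) * P(-2 | 2, 5, -1), accumulated in log space
    # and exponentiated once at the end to avoid underflow.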
    #--------------------------------------------------------------------------
    def getLeaves(self):
        leaves = collections.deque([])
        queue = collections.deque([self.root])
        node_count = 0
        while len(queue) > 0:
            node_count += 1
            node = queue.popleft()
            children = node.children.values()
            if not children:
                leaves.append(node)
            else:
                queue.extend(children)
        self.size = node_count
        return leaves
    def getSize(self):
        queue = collections.deque([self.root])
        count = 0
        while len(queue) > 0:
            count += 1
            queue.extend(queue.popleft().children.values() or [])
        return count

    def __str__(self):
        s = collections.deque([])
        queue = collections.deque([self.root])
        while len(queue) > 0:
            node = queue.popleft()
            s.append(str(node) + ', path: ' + str([n.label for n in node.nodePath()]))
            queue.extend(node.children.values() or [])
        return "\n".join(s)
    #--------------------------------------------------------------------------
    class Node:
        '''
        A node in a context tree. Contains a pointer
        to the parent and a dictionary with the children.
        '''
        def __init__(self, label, parent=None, sprouts=None):
            self.label = label
            self.parent = parent
            self.children = dict()
            self.count = 0
            self.normalized = False
            if parent:
                self.depth = parent.depth + 1
                parent.children[label] = self
            else:
                self.depth = 0
            # 'sprouts' maps a next-symbol label to the number of times it
            # followed this context.
            if sprouts:
                self.sprouts = sprouts
            else:
                self.sprouts = dict()

        def __repr__(self):
            return "<Node %s, (%s, %s, %s), {%s}>" % (self.label, self.parent, self.depth, self.count, self.sprouts.keys())

        def __str__(self):
            return "<Node %s, c: %d, d: %d, s: {%s}>" % (self.label, self.count, self.depth, self.sprouts.keys())

        def addSprout(self, label):
            self.count += 1
            self.sprouts[label] = 1 + self.sprouts.get(label, 0)

        def getSproutLogProbability(self, sprout):
            if sprout in self.sprouts:
                return number(math.log(self.sprouts[sprout]) - math.log(self.count))
            else:
                return None #-1e50 # UNKNOWN sprout

        def nodePath(self):
            "Create a list of nodes from the root to this node."
            x, result = self, [self]
            while x.parent:
                result.append(x.parent)
                x = x.parent
            result.reverse()
            return result
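        # Sketch of the bookkeeping (hypothetical counts): after the symbol 5
        # is followed twice by 2 and once by the end marker -2, the node for
        # context [5] has count == 3 and sprouts == {2: 2, -2: 1}, so
        # getSproutLogProbability(2) returns log(2) - log(3).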
#--------------------------------------------------------------------------
def readset(f):
    # First line: number of strings and alphabet size. Each following line:
    # sequence length followed by the symbols (PAutomaC format).
    line = f.readline()
    l = line.split(" ")
    num_strings = int(l[0])
    alphabet_size = int(l[1])
    sett = [[int(i) for i in l.strip().split(" ")[1:]] for l in f.readlines()]
    return alphabet_size, sett

def readprobs(f):
    # First line holds the number of probabilities; one probability per line after that.
    line = f.readline()
    probs = [number(l.strip()) for l in f.readlines()]
    return probs
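# A PAutomaC-style set file looks like this (illustrative contents):
#
#   3 4
#   2 1 3
#   1 0
#   3 2 2 1
#
# i.e. 3 sequences over an alphabet of 4 symbols, each line starting with the
# sequence length, which readset() skips with the [1:] slice.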
def writeprobs(probs, f):
    f.write(str(len(probs)) + "\n")
    for i in range(len(probs)):
        f.write(str(probs[i]) + "\n")

def normalize(arr):
    sumarr = number(sum(arr))
    if sumarr != 0.0:
        for i in range(len(arr)):
            arr[i] = arr[i] / sumarr
def main():
    parser = argparse.ArgumentParser(description="Variable length n-gram for sequence modeling", epilog="Please contact [email protected] for additional help.")
    parser.add_argument("-id", help="just printed to the output when in report mode")
    parser.add_argument("-r", "--report", metavar="solution_file", nargs='?',
                        help="turn on report mode: only some values are printed to the output, in the following order: "
                             "id, ngram, cut_value, density, len(train+test), vlngram_tree.depth, "
                             "vlngram_tree.size, total_time, perplexity, test_file")
    params_group = parser.add_argument_group('model parameters')
    params_group.add_argument("max_ngram", nargs='?', default=4, type=int, help="the maximum length of sequences for building the tree")
    params_group.add_argument("cut_value", nargs='?', default=0.001, type=float, help="the threshold for pruning the tree")
    params_group.add_argument("density", nargs='?', default=1, type=int, help="the minimum number of occurrences of a node")
    parser.add_argument("train_file", help="a file of sequences of symbols in the PAutomaC format")
    parser.add_argument("test_file", help="a file of sequences of symbols in the PAutomaC format")
    args = parser.parse_args()

    ngram = args.max_ngram
    cut_value = args.cut_value
    density = args.density
    train_file = args.train_file
    test_file = args.test_file

    alphabet, test = readset(open(test_file, "r"))
    alphabet, train = readset(open(train_file, "r"))

    start_time = time.clock()
    # The tree is built on train+test: PAutomaC distributes the test sequences
    # without their probabilities, so they can serve as extra unlabeled data.
    vlngram_tree = VLNGramTree(max_depth=ngram, cut_value=cut_value, density=density)
    vlngram_tree.train(train + test)
    vlngram_probs = vlngram_tree.predictProbabilities(test)
    normalize(vlngram_probs)
    stop_time = time.clock()
    total_time = stop_time - start_time

    # In report mode, probabilities are not saved to a file
    if not args.report:
        suffix = ".vlngram"
        prob_file = test_file + "-L" + str(ngram) + "K" + str(cut_value) + "d" + str(density) + suffix
        print "Writing probabilities to '%s'" % prob_file
        writeprobs(vlngram_probs, open(prob_file, "w"))
    else:
        solution_file = args.report
        solution_probs = readprobs(open(solution_file, "r"))
        # PAutomaC score: 2 ** (- sum_i PrT_i * log2(PrC_i)), where PrT comes
        # from the solution file and PrC is this model's normalized estimate.
        exponent = [PrT * number(math.log(PrC, 2)) for PrT, PrC in zip(solution_probs, vlngram_probs) if PrC != 0.0]
        perplexity = 2 ** -(sum(exponent))
        # Header: args.id, ngram, cut_value, density, len(train+test), vlngram_tree.depth, vlngram_tree.size, total_time, perplexity, test_file
        print '\t'.join(map("{}".format, (args.id, ngram, cut_value, density, len(train+test), vlngram_tree.depth, vlngram_tree.size, total_time, perplexity, test_file)))

if __name__ == "__main__":
    main()