Skip to content

Instantly share code, notes, and snippets.

@rgov
Last active December 15, 2015 15:39
Show Gist options
  • Select an option

  • Save rgov/5283580 to your computer and use it in GitHub Desktop.

Select an option

Save rgov/5283580 to your computer and use it in GitHub Desktop.
Generates a directed acyclic word graph for efficient membership testing for a word list.
#!/usr/bin/env python
'''
An implementation of "Incremental Construction of Minimal Acyclic Finite-State
Automata" by J. Daciuk, et al.
Data structures also inspired by "The World's Fastest Scrabble Program" by
A. W. Appel & G. J. Jacobson.
Ryan Govostes, 1 April 2013
'''
class Node(object):
    '''
    A state of the automaton: a finality flag, an ordered list of outgoing
    edges, a traversal marker, and a cached hash of the edge list.
    '''

    def __init__(self, final=False, edges=None):
        # A fresh list per node; a shared default list would alias every node.
        if edges is None:
            edges = []
        self.final = final
        self.edges = edges
        self.marker = 0
        self.updatehash()

    def __eq__(self, rhs):
        '''
        Two nodes are in the same equivalence class when:
          1 they are either both final or both non-final; and
          2 they have the same number of outgoing transitions; and
          3 corresponding outgoing transitions have the same labels; and
          4 corresponding transitions lead to the same states.
        Conditions 2-4 are covered by comparing the edge lists.
        '''
        # Cheap rejection first: differing cached hashes cannot be equal.
        if not (self.hash == rhs.hash):
            return False
        if not (self.final == rhs.final):
            return False
        return self.edges == rhs.edges

    def __hash__(self):
        return self.hash

    def updatehash(self):
        '''
        Recomputes the cached hash from the current edge list. Must be called
        after every change to the edges, because __eq__ uses the cached value
        to short-circuit the comparison.
        '''
        self.hash = hash(tuple(self.edges))

    def indexof(self, value):
        '''
        Binary-searches the edge list for the edge labelled with the given
        value. Returns its position, or -1 when there is no such edge.
        '''
        edges = self.edges
        left, right = 0, len(edges)
        while left < right:
            middle = (left + right) // 2
            if edges[middle].value < value:
                left = middle + 1
            else:
                right = middle
        found = left < len(edges) and edges[left].value == value
        return left if found else -1
class Edge(object):
    '''
    A labelled transition to a target node. Edges compare equal when their
    labels are equal and they point at the very same target object.
    '''

    def __init__(self, value, to):
        self.value = value
        self.to = to

    def __eq__(self, rhs):
        same_label = self.value == rhs.value
        return same_label and self.to is rhs.to

    def __hash__(self):
        # Uses the identity of the target, matching the `is` test in __eq__.
        key = (self.value, id(self.to))
        return hash(key)
class Dictionary(object):
    '''
    Incrementally builds an acyclic automaton from words that are added in
    lexicographic order, merging equivalent nodes through a register.
    '''

    def __init__(self):
        self.root = Node()
        self.prevword = None
        self.marker = 0

        # The register is a mapping from nodes to themselves, which will be
        # used to find the register entry of the same equivalence class, if it
        # exists. We have overridden Node.__hash__ and Node.__eq__ to support
        # this. Note that once in the registry, a node's edge list is not
        # modified, so it has a stable hash.
        self.register = {}

    def commonprefix(self, word):
        '''
        Finds the longest prefix of word that is a path in the current DFA.
        Returns the tuple (cursor, prefix, suffix), where cursor is the node
        reached by following prefix and suffix is the rest of the word.
        '''
        cursor = self.root
        for i, c in enumerate(word):
            j = cursor.indexof(c)
            if j == -1:
                return (cursor, word[:i], word[i:])
            cursor = cursor.edges[j].to

        # The whole word (possibly the empty word) is already a path, e.g.
        # when the same word is added twice. Report an empty suffix rather
        # than falling off the end and returning None.
        return (cursor, word, word[len(word):])

    def replaceorregister(self, node):
        '''
        If the last-added child of this node is equivalent to one in the
        registry, replace the child with the registry entry. Otherwise,
        register the child. Does nothing for a node without edges.
        '''
        if not node.edges:
            return

        lastedge = node.edges[-1]
        child = lastedge.to

        # Apply recursively to the grandchildren first, so the child's edge
        # list (and hash) is settled before it is looked up.
        self.replaceorregister(child)

        # Look for a node in the register with the same class as the last
        # child, and simply replace the last child with the register node.
        q = self.register.get(child)
        if q is not None:
            node.edges[-1] = Edge(lastedge.value, q)

            # Would not be needed if Edge.__hash__ were not based on object
            # id, but otherwise there is a circular dependency between it and
            # Node.__hash__.
            node.updatehash()
        else:
            # Otherwise, this child goes in the register.
            self.register[child] = child

    def addword(self, word):
        '''
        Adds a new word to the dictionary. The word must follow the previous
        addition in lexicographic order; adding the same word again is a
        harmless no-op.

        Raises ValueError (a subclass of Exception) when the word sorts
        before the previously added one.
        '''
        # prevword is None until the first word is added; comparing a string
        # with None is an error on Python 3, so test for it explicitly.
        if self.prevword is not None and word < self.prevword:
            raise ValueError('Words must be added in lexicographic order')
        self.prevword = word

        last, _, suffix = self.commonprefix(word)
        self.replaceorregister(last)
        self.addsuffix(last, suffix)

    def addsuffix(self, cursor, suffix):
        '''
        Inserts the nodes and edges needed to accept the suffix, starting at
        cursor, and marks the node at the end of the suffix as final.
        '''
        for c in suffix:
            n = Node()
            cursor.edges.append(Edge(value=c, to=n))
            cursor.updatehash()
            cursor = n
        cursor.final = True

    def finalize(self):
        '''
        Performs a final replace-or-register step starting from the root node.
        '''
        self.replaceorregister(self.root)

    def bfs(self):
        '''
        A generator that yields each reachable node once, in breadth-first
        order starting at the root.
        '''
        from collections import deque

        # Use a different marker each time, rather than have to clear the
        # nodes' markers afterwards.
        self.marker += 1
        marker = self.marker

        # Nodes are marked when enqueued so that a node shared by many edges
        # enters the queue only once.
        self.root.marker = marker
        queue = deque([self.root])
        while queue:
            cursor = queue.popleft()
            yield cursor
            for e in cursor.edges:
                target = e.to
                if target.marker != marker:
                    target.marker = marker
                    queue.append(target)
def GenerateIndex(graph):
    '''
    Walks graph.bfs() and stores on each node, as ``index``, the position at
    which its edge list starts. Nodes without outgoing edges instead get the
    sentinel 0x001FFFFF (all ones in 21 bits). Returns the total number of
    edges seen, i.e. one greater than the last edge position used.
    '''
    no_edges_index = 0x001FFFFF
    total = 0
    for node in graph.bfs():
        degree = len(node.edges)
        if degree > 0:
            node.index = total
        else:
            node.index = no_edges_index
        total += degree
    return total
def DotPrinter(graph):
    '''
    Renders the graph as DOT text for Graphviz and returns it as a string.
    Nodes are named after the index assigned by GenerateIndex; final nodes
    are drawn as double circles.
    '''
    GenerateIndex(graph)

    pieces = ['digraph dawg {\n']

    # First pass: one declaration per node.
    for node in graph.bfs():
        if node.final:
            shape = 'doublecircle'
        else:
            shape = 'circle'
        pieces.append(
            '\tN{0} [label={0}, shape={1}]\n'.format(node.index, shape))

    # Second pass: one line per labelled transition.
    for node in graph.bfs():
        for edge in node.edges:
            pieces.append(
                '\tN{0} -> N{1} [label={2}]\n'.format(
                    node.index, edge.to.index, edge.value))

    pieces.append('}\n')
    return ''.join(pieces)
def BinaryPrinter(graph):
    '''
    Pack each edge into a 32-bit integer, with the following fields (most
    significant first):
       1 bit:  the target node for this edge is final
       5 bits: if first edge of its node, # of edges of that node, else 0
       5 bits: the value of this edge (a = 0, ..., z = 25)
      21 bits: the index of the first edge of the target node

    Returns the concatenation of all packed edges as a string in which each
    character carries one byte value, most significant byte first.

    Raises Exception when the edge count does not fit the 21-bit index field,
    and ValueError (a subclass of Exception) when an edge label does not fit
    the 5-bit value field.
    '''
    # Hand-rolled big-endian encoding, to avoid a dependency on the struct
    # module.
    def uint32(x):
        return ''.join(chr((x >> sh) & 0xFF) for sh in (24, 16, 8, 0))

    def pack(final, count, value, offset):
        return uint32((final << 31) | (count << 26) | (value << 21) | offset)

    # Determine the index for the start of the edge list of each node
    nextid = GenerateIndex(graph)
    if nextid >= (1 << 21):
        raise Exception('Too many nodes to pack into 32 bits')

    # Pack each edge in order
    packed = []
    for node in graph.bfs():
        # Use a separate name for the edge count so that the loop variable
        # holding the node is never overwritten.
        count = len(node.edges)
        for i, e in enumerate(node.edges):
            v = ord(e.value) - ord('a')
            # A label outside the field would spill into neighbouring bits
            # (or go negative) and silently corrupt the record.
            if not 0 <= v < (1 << 5):
                raise ValueError(
                    'Edge value %r does not fit into 5 bits' % (e.value,))
            f = 1 if e.to.final else 0
            packed.append(pack(f, count if i == 0 else 0, v, e.to.index))
    return ''.join(packed)
if __name__ == '__main__':
    import sys

    # Pick the output format from the single command-line argument.
    if len(sys.argv) == 2 and sys.argv[1] == '--dot':
        printer = DotPrinter
    elif len(sys.argv) == 2 and sys.argv[1] == '--binary':
        printer = BinaryPrinter
    else:
        # sys.stderr.write works on both Python 2 and Python 3, unlike the
        # Python-2-only ``print >>`` statement.
        sys.stderr.write("Specify output format: --dot, --binary\n")
        sys.exit(1)

    # Read each word from stdin and add it to the dictionary
    d = Dictionary()
    for w in sys.stdin:
        w = w.rstrip().lower()
        # Skip blank lines: they are not words and must not be added as the
        # empty string.
        if not w:
            continue
        d.addword(w)
    d.finalize()

    sys.stdout.write(printer(d))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment