@KrishnaPG
Forked from Arachnid/automata.py
Created August 6, 2017 15:20
import bisect


class NFA(object):
    EPSILON = object()
    ANY = object()

    def __init__(self, start_state):
        self.transitions = {}
        self.final_states = set()
        self._start_state = start_state

    @property
    def start_state(self):
        return frozenset(self._expand(set([self._start_state])))

    def add_transition(self, src, input, dest):
        self.transitions.setdefault(src, {}).setdefault(input, set()).add(dest)

    def add_final_state(self, state):
        self.final_states.add(state)

    def is_final(self, states):
        return self.final_states.intersection(states)

    def _expand(self, states):
        frontier = set(states)
        while frontier:
            state = frontier.pop()
            new_states = self.transitions.get(state, {}).get(NFA.EPSILON, set()).difference(states)
            frontier.update(new_states)
            states.update(new_states)
        return states

    def next_state(self, states, input):
        dest_states = set()
        for state in states:
            state_transitions = self.transitions.get(state, {})
            dest_states.update(state_transitions.get(input, []))
            dest_states.update(state_transitions.get(NFA.ANY, []))
        return frozenset(self._expand(dest_states))

    def get_inputs(self, states):
        inputs = set()
        for state in states:
            inputs.update(self.transitions.get(state, {}).keys())
        return inputs
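
    # to_dfa() is the standard subset (powerset) construction: each DFA state is a
    # frozenset of NFA states, and NFA.ANY transitions become the DFA's default
    # transitions.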
    def to_dfa(self):
        dfa = DFA(self.start_state)
        frontier = [self.start_state]
        seen = set()
        while frontier:
            current = frontier.pop()
            inputs = self.get_inputs(current)
            for input in inputs:
                if input == NFA.EPSILON: continue
                new_state = self.next_state(current, input)
                if new_state not in seen:
                    frontier.append(new_state)
                    seen.add(new_state)
                    if self.is_final(new_state):
                        dfa.add_final_state(new_state)
                if input == NFA.ANY:
                    dfa.set_default_transition(current, new_state)
                else:
                    dfa.add_transition(current, input, new_state)
        return dfa


class DFA(object):
    def __init__(self, start_state):
        self.start_state = start_state
        self.transitions = {}
        self.defaults = {}
        self.final_states = set()

    def add_transition(self, src, input, dest):
        self.transitions.setdefault(src, {})[input] = dest

    def set_default_transition(self, src, dest):
        self.defaults[src] = dest

    def add_final_state(self, state):
        self.final_states.add(state)

    def is_final(self, state):
        return state in self.final_states

    def next_state(self, src, input):
        state_transitions = self.transitions.get(src, {})
        return state_transitions.get(input, self.defaults.get(src, None))
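
    # next_valid_string(input) returns the lexicographically smallest string
    # accepted by this DFA that is greater than or equal to input, or None if
    # no such string exists.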
    def next_valid_string(self, input):
        state = self.start_state
        stack = []
        # Evaluate the DFA as far as possible
        for i, x in enumerate(input):
            stack.append((input[:i], state, x))
            state = self.next_state(state, x)
            if not state: break
        else:
            stack.append((input[:i + 1], state, None))
            if self.is_final(state):
                # Input word is already valid
                return input
        # Perform a 'wall following' search for the lexicographically smallest
        # accepting state.
        while stack:
            path, state, x = stack.pop()
            x = self.find_next_edge(state, x)
            if x:
                path += x
                state = self.next_state(state, x)
                if self.is_final(state):
                    return path
                stack.append((path, state, None))
        return None
    def find_next_edge(self, s, x):
        # Return the smallest outgoing label from state s that is strictly greater
        # than x (or the smallest label at all if x is None). A default transition
        # on s matches any character, so the candidate itself can be returned.
        if x is None:
            x = u'\0'
        else:
            x = chr(ord(x) + 1)  # the original Python 2 code used unichr()
        state_transitions = self.transitions.get(s, {})
        if x in state_transitions or s in self.defaults:
            return x
        labels = sorted(state_transitions.keys())
        pos = bisect.bisect_left(labels, x)
        if pos < len(labels):
            return labels[pos]
        return None
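

# Build an NFA that accepts every string within Levenshtein distance k of term.
# Each NFA state is a pair (i, e): i characters of term consumed, e errors used so far.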
def levenshtein_automata(term, k):
    nfa = NFA((0, 0))
    for i, c in enumerate(term):
        for e in range(k + 1):
            # Correct character
            nfa.add_transition((i, e), c, (i + 1, e))
            if e < k:
                # Deletion
                nfa.add_transition((i, e), NFA.ANY, (i, e + 1))
                # Insertion
                nfa.add_transition((i, e), NFA.EPSILON, (i + 1, e + 1))
                # Substitution
                nfa.add_transition((i, e), NFA.ANY, (i + 1, e + 1))
    for e in range(k + 1):
        if e < k:
            nfa.add_transition((len(term), e), NFA.ANY, (len(term), e + 1))
        nfa.add_final_state((len(term), e))
    return nfa


def find_all_matches(word, k, lookup_func):
    """Uses lookup_func to find all words within Levenshtein distance k of word.

    Args:
        word: The word to look up.
        k: Maximum edit distance.
        lookup_func: A single-argument function that returns the first word in the
            database that is greater than or equal to the input argument.

    Yields:
        Every matching word within Levenshtein distance k from the database.
    """
    lev = levenshtein_automata(word, k).to_dfa()
    match = lev.next_valid_string(u'\0')
    while match:
        next = lookup_func(match)
        if not next:
            return
        if match == next:
            yield match
            next = next + u'\0'
        match = lev.next_valid_string(next)
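

# The code below is a separate test/benchmark script in the gist; it imports the
# code above as the `automata` module (i.e. the classes and functions above live
# in automata.py).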
import automata
import bisect
import random


class Matcher(object):
    def __init__(self, l):
        self.l = l
        self.probes = 0

    def __call__(self, w):
        self.probes += 1
        pos = bisect.bisect_left(self.l, w)
        if pos < len(self.l):
            return self.l[pos]
        else:
            return None


# Under Python 3 the file is read as text, so the original .decode('utf-8') is not needed.
words = [x.strip().lower() for x in open('/usr/share/dict/web2', encoding='utf-8')]
words.sort()
words10 = [x for x in words if random.random() <= 0.1]
words100 = [x for x in words if random.random() <= 0.01]

m = Matcher(words)
assert len(list(automata.find_all_matches('food', 1, m))) == 18
print(m.probes)

m = Matcher(words)
assert len(list(automata.find_all_matches('food', 2, m))) == 283
print(m.probes)


def levenshtein(s1, s2):
    if len(s1) < len(s2):
        return levenshtein(s2, s1)
    if not s1:
        return len(s2)
    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1  # j+1 instead of j since previous_row and
            deletions = current_row[j] + 1        # current_row are one character longer than s2
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    return previous_row[-1]
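

# BK-tree (Burkhard-Keller tree), presumably included for comparison with the
# automaton approach above: search() returns the number of nodes visited (and
# hence Levenshtein computations performed) while collecting matches.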
class BKNode(object):
    def __init__(self, term):
        self.term = term
        self.children = {}

    def insert(self, other):
        distance = levenshtein(self.term, other)
        if distance in self.children:
            self.children[distance].insert(other)
        else:
            self.children[distance] = BKNode(other)

    def search(self, term, k, results=None):
        if results is None:
            results = []
        distance = levenshtein(self.term, term)
        counter = 1
        if distance <= k:
            results.append(self.term)
        for i in range(max(0, distance - k), distance + k + 1):
            child = self.children.get(i)
            if child:
                counter += child.search(term, k, results)
        return counter
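

# Usage sketch (not part of the original gist): build a BK-tree over the small
# sample word list from above and query it. This assumes `words100` from the
# benchmark script and the BKNode/levenshtein definitions here are in scope.
root = BKNode(words100[0])
for w in words100[1:]:
    root.insert(w)
matches = []
nodes_visited = root.search('food', 1, matches)
# search() returns the number of nodes visited; matches collects terms within distance k.
# The node count can be compared against the Matcher probe counts printed earlier.
print(len(matches), nodes_visited)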