-
-
Save hiratara/3122254 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bisect | |
class NFA(object):
    """Nondeterministic finite automaton with epsilon and wildcard edges.

    States can be any hashable values.  ``transitions`` maps a source state
    to a dict of ``{input: set_of_destination_states}``.  Two sentinel inputs
    are recognised: ``NFA.EPSILON`` (followed without consuming input) and
    ``NFA.ANY`` (followed on every consumed input symbol).
    """

    # Unique sentinel objects; they can never collide with a real input symbol.
    EPSILON = object()
    ANY = object()

    def __init__(self, start_state):
        """Create an automaton with no transitions that starts in *start_state*."""
        self.transitions = {}
        self.final_states = set()
        self._start_state = start_state

    @property
    def start_state(self):
        """Frozen epsilon-closure of the single start state."""
        return frozenset(self._expand({self._start_state}))

    def add_transition(self, src, input, dest):
        """Add an edge from *src* to *dest* labelled *input*."""
        self.transitions.setdefault(src, {}).setdefault(input, set()).add(dest)

    def add_final_state(self, state):
        """Mark *state* as accepting."""
        self.final_states.add(state)

    def is_final(self, states):
        """Return the accepting states contained in *states*.

        The result is a set, so it is truthy exactly when at least one of
        *states* is accepting.
        """
        return self.final_states.intersection(states)

    def _expand(self, states):
        """Add everything reachable through epsilon edges to *states*.

        *states* must be a mutable set; it is updated in place and also
        returned.
        """
        frontier = set(states)
        while frontier:
            state = frontier.pop()
            epsilon_targets = self.transitions.get(state, {}).get(NFA.EPSILON, set())
            new_states = epsilon_targets.difference(states)
            frontier.update(new_states)
            states.update(new_states)
        return states

    def next_state(self, states, input):
        """Return the frozen state set reached from *states* on *input*.

        Destinations of edges labelled *input* and of ``NFA.ANY`` edges are
        both collected, then closed over epsilon edges.
        """
        dest_states = set()
        for state in states:
            state_transitions = self.transitions.get(state, {})
            dest_states.update(state_transitions.get(input, []))
            dest_states.update(state_transitions.get(NFA.ANY, []))
        return frozenset(self._expand(dest_states))

    def get_inputs(self, states):
        """Return every edge label (sentinels included) leaving any of *states*."""
        inputs = set()
        for state in states:
            inputs.update(self.transitions.get(state, {}))
        return inputs

    def to_dfa(self):
        """Build an equivalent ``DFA`` by exploring reachable state sets.

        Each DFA state is a frozenset of NFA states.  ``NFA.ANY`` edges become
        the DFA state's default transition; all other labels become explicit
        transitions.
        """
        start = self.start_state
        dfa = DFA(start)
        # The start set never passes through the discovery branch below, so
        # it has to be registered (and checked for acceptance) up front.
        if self.is_final(start):
            dfa.add_final_state(start)
        frontier = [start]
        seen = {start}
        while frontier:
            current = frontier.pop()
            for input in self.get_inputs(current):
                if input is NFA.EPSILON:
                    # Epsilon edges are already folded into every state set.
                    continue
                new_state = self.next_state(current, input)
                if new_state not in seen:
                    frontier.append(new_state)
                    seen.add(new_state)
                    if self.is_final(new_state):
                        dfa.add_final_state(new_state)
                if input is NFA.ANY:
                    dfa.set_default_transition(current, new_state)
                else:
                    dfa.add_transition(current, input, new_state)
        return dfa
class DFA(object):
    """Deterministic finite automaton with optional per-state default edges.

    ``transitions`` maps a source state to ``{input: destination}``.
    ``defaults`` maps a source state to the destination used when no explicit
    transition exists for the input.
    """

    def __init__(self, start_state):
        """Create an automaton with no transitions that starts in *start_state*."""
        self.start_state = start_state
        self.transitions = {}
        self.defaults = {}
        self.final_states = set()

    def add_transition(self, src, input, dest):
        """Set the destination reached from *src* on *input*."""
        self.transitions.setdefault(src, {})[input] = dest

    def set_default_transition(self, src, dest):
        """Set the destination reached from *src* on any unlisted input."""
        self.defaults[src] = dest

    def add_final_state(self, state):
        """Mark *state* as accepting."""
        self.final_states.add(state)

    def is_final(self, state):
        """Return True if *state* is accepting."""
        return state in self.final_states

    def next_state(self, src, input):
        """Return the state reached from *src* on *input*.

        Falls back to the default transition of *src*, and returns ``None``
        when neither an explicit nor a default transition exists.
        """
        state_transitions = self.transitions.get(src, {})
        return state_transitions.get(input, self.defaults.get(src, None))

    def run(self, input):
        """Return True if the automaton accepts the whole of *input*.

        Acceptance is decided only after every symbol has been consumed, so
        a word whose prefix is accepted but whose remainder is not is
        rejected.  Empty input is accepted exactly when the start state is
        final.
        """
        state = self.start_state
        for symbol in input:
            state = self.next_state(state, symbol)
            # Compare with None explicitly: a legitimate state label may be
            # falsy (0, "", an empty tuple, ...).
            if state is None:
                return False
        return self.is_final(state)
def levenshtein_automata(term, k):
    """Build the Levenshtein NFA for *term* with an error budget of *k*.

    States are ``(position, errors)`` tuples, where ``position`` counts the
    characters of *term* matched so far and ``errors`` the edits spent.  The
    start state is ``(0, 0)`` and every ``(len(term), errors)`` state with
    ``errors`` in ``0..k`` is final.  Returns the ``NFA``; call ``to_dfa()``
    on it to get a runnable automaton.
    """
    term_length = len(term)
    automaton = NFA((0, 0))

    for position, char in enumerate(term):
        for errors in range(k + 1):
            here = (position, errors)
            # The expected character: advance without spending an error.
            automaton.add_transition(here, char, (position + 1, errors))
            if errors >= k:
                # Error budget exhausted; only the exact match remains.
                continue
            # Consume any input character without advancing in the term
            # (the original source labels this edge "Deletion").
            automaton.add_transition(here, NFA.ANY, (position, errors + 1))
            # Advance in the term without consuming input
            # (the original source labels this edge "Insertion").
            automaton.add_transition(here, NFA.EPSILON, (position + 1, errors + 1))
            # Substitution: consume any input character and advance.
            automaton.add_transition(here, NFA.ANY, (position + 1, errors + 1))

    # Past the end of the term, extra input characters each cost one error.
    for errors in range(k + 1):
        end = (term_length, errors)
        if errors < k:
            automaton.add_transition(end, NFA.ANY, (term_length, errors + 1))
        automaton.add_final_state(end)

    return automaton
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This code is copied from http://d.hatena.ne.jp/naoya/20090329/1238307757 | |
def levenshtein_distance(a, b):
    """Return the Levenshtein (edit) distance between sequences *a* and *b*.

    Insertions, deletions and substitutions each cost 1.  Only two rows of
    the dynamic-programming table are kept at a time, so memory use is
    proportional to ``len(b)`` rather than ``len(a) * len(b)``.
    """
    # Row 0: turning the empty prefix of `a` into b[:j] takes j insertions.
    previous = list(range(len(b) + 1))
    for i, item_a in enumerate(a, 1):
        # Column 0: turning a[:i] into the empty prefix takes i deletions.
        current = [i]
        for j, item_b in enumerate(b, 1):
            cost = 0 if item_a == item_b else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # match / substitution
        previous = current
    return previous[-1]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import timeit | |
import automata | |
import distance | |
# Benchmark parameters.
string = "levenshtein"  # reference term that every candidate is compared with
ambiguity = 2           # maximum edit distance that still counts as a match
word_num = 100          # number of smudged candidate words to generate
repeat_num = 100        # number of timed passes over the candidates
# Create test data | |
def smudge_str(string, depth):
    """Return *string* with *depth* random single-character overwrites.

    Each overwrite picks a random position and replaces it with a random
    lowercase letter from 'a' to 'z' inclusive.  Positions may repeat and a
    replacement may equal the old character, so at most *depth* characters
    differ from the input.  The length never changes; an empty string is
    returned unchanged.
    """
    chars = list(string)
    if not chars:
        # There is no position to overwrite.
        return ""
    for _ in range(depth):
        position = random.randrange(len(chars))
        # randint includes both endpoints, so 'z' can be produced as well.
        chars[position] = chr(random.randint(ord('a'), ord('z')))
    return "".join(chars)
# Candidate words: word_num copies of the reference term, each given
# ambiguity + 1 random character overwrites.
inputs = [smudge_str(string, ambiguity + 1) for _ in range(word_num)]
# Create the automata logic | |
# Build the automaton once, outside the timed section.
nfa = automata.levenshtein_automata(string, ambiguity)
dfa = nfa.to_dfa()


def automata_func(str):
    """Return the prebuilt DFA's verdict for the candidate word *str*."""
    verdict = dfa.run(str)
    return verdict
# Create the distance logic | |
def distance_func(str):
    """Return True if *str* is within ``ambiguity`` edits of the reference term."""
    edit_distance = distance.levenshtein_distance(string, str)
    return edit_distance <= ambiguity
# Run benchmarks | |
for func in [automata_func, distance_func]:
    def run_all():
        """Apply the matcher currently under test to every candidate word."""
        # `func` is looked up when run_all is called; that is safe because
        # the timing below happens inside the same loop iteration.
        for word in inputs:
            func(word)

    # Timer accepts a callable directly, so nothing has to be injected into
    # the timeit module's namespace.
    timer = timeit.Timer(run_all)
    # Total seconds for repeat_num passes; the parenthesised form prints the
    # same text on Python 2 and Python 3.
    print(timer.timeit(repeat_num))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment