Last active
August 29, 2015 14:01
-
-
Save andlima/e823ba48f9a03743f36c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
class Trie(object):
    """A set of strings stored as a prefix tree (trie).

    Each node records whether the path from the root down to it spells a
    stored word (``ends``) and maps the next character to a child node
    (``children``).  Iterating over a trie yields its words in sorted order.

    Every operation walks the tree with an explicit loop or stack instead of
    recursion, so words longer than the interpreter's recursion limit are
    handled without ``RecursionError``.
    """

    def __init__(self, collection=None):
        """Create a trie, optionally adding every word in *collection*."""
        self.ends = False   # True if a stored word ends exactly at this node
        self.children = {}  # next character -> child Trie
        if collection is not None:
            for s in collection:
                self.add(s)

    def add(self, s):
        """Insert *s*; return True if it was new, False if already present."""
        return self._add_last_chars(s, 0)

    def remove(self, s):
        """Delete *s*; return True if it was present, False otherwise.

        Nodes that no longer lead to any stored word are pruned, so removing
        words does not leave dead branches behind.
        """
        return self._remove_last_chars(s, 0)

    def clear(self):
        """Remove every word from the trie."""
        self.ends = False
        self.children = {}

    def __contains__(self, s):
        """Return True if *s* is a stored word (not merely a prefix)."""
        return self._contains_last_chars(s, 0)

    def __iter__(self):
        """Yield the stored words in sorted order."""
        return self._iter_with_prefix('')

    def __len__(self):
        """Return the number of stored words (visits every node)."""
        count = 0
        stack = [self]
        while stack:
            node = stack.pop()
            if node.ends:
                count += 1
            stack.extend(node.children.values())
        return count

    def _add_last_chars(self, s, i):
        """Insert the suffix ``s[i:]`` below this node.

        Returns True if the word was new, False if it was already stored.
        """
        node = self
        for j in range(i, len(s)):
            c = s[j]
            child = node.children.get(c)
            if child is None:
                # Create the missing node only when needed (no throwaway
                # allocation as with dict.setdefault).
                child = node.children[c] = Trie()
            node = child
        if node.ends:
            return False
        node.ends = True
        return True

    def _remove_last_chars(self, s, i):
        """Remove the suffix ``s[i:]`` below this node and prune empty nodes.

        Returns True if the word was stored (and is now removed), else False.
        """
        path = []  # (parent, char) pairs from this node down to the word's node
        node = self
        for j in range(i, len(s)):
            c = s[j]
            child = node.children.get(c)
            if child is None:
                return False
            path.append((node, c))
            node = child
        if not node.ends:
            return False
        node.ends = False
        # Walk back up, deleting nodes that now hold neither a word nor
        # children; stop at the first node that is still needed.
        for parent, c in reversed(path):
            child = parent.children[c]
            if child.ends or child.children:
                break
            del parent.children[c]
        return True

    def _contains_last_chars(self, s, i):
        """Return True if the suffix ``s[i:]`` is stored below this node."""
        node = self
        for j in range(i, len(s)):
            node = node.children.get(s[j])
            if node is None:
                return False
        return node.ends

    def _iter_with_prefix(self, prefix):
        """Yield ``prefix + w`` for every word ``w`` stored below this node.

        Words are produced in pre-order with children visited in sorted
        character order, i.e. in sorted order overall.
        """
        stack = [(prefix, self)]
        while stack:
            word, node = stack.pop()
            if node.ends:
                yield word
            # Push in reverse so the smallest character is popped next.
            for c, child in sorted(node.children.items(), reverse=True):
                stack.append((word + c, child))
if __name__ == '__main__':
    # Self-test for Trie: exercises add/remove/contains/len/iter/clear and
    # the constructor's optional collection argument.

    def check_membership(t, present, absent):
        """Assert every word in `present` is in t and no word in `absent` is."""
        for word in present:
            assert word in t, word
        for word in absent:
            assert word not in t, word

    trie = Trie()
    assert len(trie) == 0
    check_membership(trie, [], ['', 'a', 'b', 'abc'])

    # Each step adds one new word: (word, expected present, expected absent).
    steps = [
        ('a', ['a'], ['', 'b', 'abc']),
        ('', ['', 'a'], ['b', 'abc']),
        ('b', ['', 'a', 'b'], ['abc']),
        ('abc', ['', 'a', 'b', 'abc'], ['aa', 'ab']),
    ]
    for expected_len, (word, present, absent) in enumerate(steps, 1):
        assert trie.add(word)
        assert len(trie) == expected_len
        check_membership(trie, present, absent)

    # Adding a word that is already stored reports False and changes nothing.
    for word in ('', 'abc', 'b', 'a'):
        assert not trie.add(word)
    assert len(trie) == 4
    assert list(trie) == ['', 'a', 'abc', 'b']

    # Removal succeeds once; a second attempt reports False.
    assert trie.remove('a')
    assert len(trie) == 3
    assert list(trie) == ['', 'abc', 'b']
    assert not trie.remove('a')
    assert len(trie) == 3
    assert list(trie) == ['', 'abc', 'b']

    trie.clear()
    assert len(trie) == 0
    assert list(trie) == []

    trie2 = Trie(['x', 'pq', 'xa'])
    assert list(trie2) == ['pq', 'x', 'xa']
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment