Created
September 24, 2017 04:28
-
-
Save sushlala/8eaeedd167aa877f8f8a38a1e94a7eeb to your computer and use it in GitHub Desktop.
a trie implemented in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import islice | |
from collections import defaultdict | |
from Queue import Queue | |
class Node(object):
    """A single trie node.

    `children` maps the next key element to the child node for it.
    `val` is assigned only on nodes where a stored key ends, so its
    absence marks a node as purely intermediate.
    """

    __slots__ = ('children', 'val')

    def __init__(self):
        # Missing children are created on first access, using this node's
        # own type as the factory.
        node_type = type(self)
        self.children = defaultdict(node_type)
class Trie(object):
    """A trie (prefix tree) that may be accessed like a dict.

    Keys are iterables of hashable elements (typically strings). Each
    element selects one child per level, and the value is stored on the
    node reached after the last element.
    """

    def __init__(self):
        self.root = Node()

    def _find_node(self, key):
        """Return the node reached by following `key` from the root.

        Returns None when the path does not exist. Never creates nodes.
        """
        node = self.root
        for element in key:
            # Test membership before indexing: `children` creates missing
            # entries on item access, and a lookup must not grow the trie.
            if element not in node.children:
                return None
            node = node.children[element]
        return node

    def __setitem__(self, key, val):
        """Store `val` at `key`, creating intermediate nodes as needed.

        An empty key stores the value on the root node.
        """
        node = self.root
        for element in key:
            # Indexing `children` inserts a fresh node when the element
            # has not been seen at this level yet.
            node = node.children[element]
        node.val = val

    def __getitem__(self, key):
        """Return the value stored at `key`.

        Raises:
            KeyError: if `key` was never stored, including when it is only
                a prefix of stored keys.
        """
        node = self._find_node(key)
        # A node without `val` exists only as part of a longer key's path.
        if node is None or not hasattr(node, 'val'):
            raise KeyError(key)
        return node.val

    def __contains__(self, key):
        """Return True if a value is stored at `key`, else False."""
        try:
            self[key]
        except KeyError:
            return False
        return True

    def key_starts_with(self, prefix):
        """Yield every stored key that starts with `prefix`.

        Keys are produced in breadth-first order, so shorter keys come
        first. Nothing is yielded when no stored key has the prefix. Keys
        are rebuilt as `prefix + element + ...`, so `prefix` and the key
        elements must support `+` with each other (e.g. strings).
        """
        start = self._find_node(prefix)
        if start is None:
            # A plain return ends the generator; raising StopIteration
            # here becomes a RuntimeError under PEP 479.
            return
        # Breadth-first search, one level at a time, starting at the node
        # that represents the end of `prefix`.
        level = [(prefix, start)]
        while level:
            next_level = []
            for key, node in level:
                if hasattr(node, 'val'):
                    yield key
                for element, child in node.children.items():
                    next_level.append((key + element, child))
            level = next_level
if __name__ == '__main__':
    # Small demonstration of the Trie API. Each message is formatted into a
    # single string so that print(...) emits the same text whether it is
    # parsed as a Python 2 statement or a Python 3 function call.
    a = Trie()
    # Add some items to the trie.
    for k, v in [('abc', 1), ('abcdef', 2), ('abcdefgh', 3), ('zzz', 6)]:
        a[k] = v
    print('does the trie contain key `ppp`? %s' % ('ppp' in a,))
    print('does the trie contain key `abc`? %s' % ('abc' in a,))
    print('value of trie key `zzz` %s' % (a['zzz'],))
    print('keys that start with `abcde` in trie %s'
          % (list(a.key_starts_with('abcde')),))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment