Created
May 8, 2011 22:43
-
-
Save tyler/961773 to your computer and use it in GitHub Desktop.
serializing trie
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from struct import * | |
# Serialized node layout.  Each node is written as a fixed header followed by
# one entry per child; each child's subtree is written after its parent's
# entries, depth first.  The '=' prefix selects native byte order with
# standard sizes and no padding, so the output is only portable between
# machines of the same endianness.
_HEADER_FORMAT = '=HI'  # (node size in bytes incl. entries, value; 0 = none)
_ENTRY_FORMAT = '=BI'   # (key byte, absolute offset of the child node)
_HEADER_SIZE = calcsize(_HEADER_FORMAT)  # 6 bytes
_ENTRY_SIZE = calcsize(_ENTRY_FORMAT)    # 5 bytes


def _node_size(child_count):
    """Return the serialized size in bytes of a node with *child_count* children."""
    return _HEADER_SIZE + child_count * _ENTRY_SIZE


class Node(object):
    """Trie node mapping byte-string keys to integer values.

    Keys may be byte strings (Python 2 ``str`` / Python 3 ``bytes``), text
    (Python 2 ``unicode`` / Python 3 ``str``, encoded as UTF-8 for both
    insertion and lookup) or ``bytearray``.  Each edge is labelled with a
    single byte, stored as an int.

    The trie can be flattened with ``serialize`` and rebuilt with
    ``deserialize``.  Format limitations: a stored value of ``0`` is written
    as "no value" and therefore does not survive a round trip, and values
    must fit in an unsigned 32-bit integer or ``serialize`` raises
    ``struct.error``.
    """

    def __init__(self):
        self.children = {}     # key byte (int) -> child Node
        self.terminal = False  # True when some key ends exactly at this node
        self.value = None      # value stored for that key; None otherwise

    @staticmethod
    def deserialize(serialized_trie):
        """Rebuild a trie from the byte string produced by ``serialize``.

        The input is not validated; it is assumed to be well-formed output of
        ``serialize``.
        """
        return Node._deserialize_node(serialized_trie, 0)

    @staticmethod
    def _deserialize_node(data, offset):
        """Decode the node starting at *offset* in *data*, including its subtree."""
        node = Node()
        size, value = unpack_from(_HEADER_FORMAT, data, offset)
        node.terminal = value != 0
        node.value = value if node.terminal else None
        # The header stores the node's total size; the number of child
        # entries follows from it.  Floor division keeps this an int on
        # Python 3 as well as Python 2.
        child_count = (size - _HEADER_SIZE) // _ENTRY_SIZE
        entry_offset = offset + _HEADER_SIZE
        for _ in range(child_count):
            k, child_offset = unpack_from(_ENTRY_FORMAT, data, entry_offset)
            # Children live in the same buffer as their parent.
            node.children[k] = Node._deserialize_node(data, child_offset)
            entry_offset += _ENTRY_SIZE
        return node

    def serialize(self):
        """Return the trie rooted at this node as a byte string.

        Child offsets inside the result are absolute positions in that byte
        string, with this node written at offset 0.
        """
        _, output = self._serialize_node(0)
        return output

    def _serialize_node(self, starting_offset):
        """Serialize this subtree as if it began at *starting_offset*.

        Returns ``(end_offset, data)``, where *end_offset* is the offset just
        past the subtree (where the next sibling subtree will start).
        """
        nodesize = _node_size(len(self.children))
        value = self.value or 0  # 0 encodes "no value" in the format
        head = [pack(_HEADER_FORMAT, nodesize, value)]
        subtrees = []
        next_child_offset = starting_offset + nodesize
        # Sorted so equal tries serialize to identical bytes regardless of the
        # order in which keys were inserted.
        for k in sorted(self.children):
            head.append(pack(_ENTRY_FORMAT, k, next_child_offset))
            child = self.children[k]
            next_child_offset, serialized_child = child._serialize_node(
                next_child_offset)
            subtrees.append(serialized_child)
        # Collect parts and join once instead of repeated concatenation.
        return next_child_offset, b''.join(head + subtrees)

    @staticmethod
    def _key_bytes(key):
        """Return *key* as the bytearray of bytes used to walk the trie.

        Raises:
            TypeError: if *key* is not a byte string, text string or bytearray.
        """
        if isinstance(key, bytearray):
            return key
        if isinstance(key, bytes):        # Python 2 str / Python 3 bytes
            return bytearray(key)
        if isinstance(key, type(u'')):    # Python 2 unicode / Python 3 str
            return bytearray(key.encode('utf-8'))
        raise TypeError('Expected string, unicode, or bytearray as key')

    def __setitem__(self, key, value):
        """Store the int *value* under *key*.

        Raises:
            TypeError: if *value* is not exactly ``int`` (bool is rejected) or
                *key* has an unsupported type.
        """
        if type(value) is not int:
            raise TypeError('Expected int as value')
        self._add(self._key_bytes(key), value)

    def __getitem__(self, key):
        """Return the value stored under *key*.

        Raises:
            IndexError: if *key* is not stored (kept for backward
                compatibility, although mappings conventionally raise KeyError).
            TypeError: if *key* has an unsupported type.
        """
        value = self._retrieve(key)
        if value is None:
            raise IndexError(key)
        return value

    def __contains__(self, key):
        """Return True if a value is stored under *key*."""
        return self._retrieve(key) is not None

    def _add(self, key, value):
        """Insert *value* along the path spelled by the bytes of *key*."""
        node = self
        for byte in key:
            child = node.children.get(byte)
            if child is None:
                child = node.children[byte] = Node()
            node = child
        node.terminal = True
        node.value = value

    def _retrieve(self, key):
        """Return the value stored under *key*, or None if there is none.

        The key is normalized exactly as in ``__setitem__`` so that text,
        byte-string and bytearray keys find the same entries.
        """
        node = self
        for byte in self._key_bytes(key):
            node = node.children.get(byte)
            if node is None:
                return None
        return node.value if node.terminal else None
if __name__ == '__main__':
    # Small demo: build a trie, print its values, then print them again after
    # a serialize/deserialize round trip.  print(...) with a single argument
    # produces the same output on Python 2 and Python 3.
    samples = [('monkey', 1), ('monk', 2), ('monkeys', 3), ('foobar', 4)]

    trie = Node()
    for word, count in samples:
        trie[word] = count
    for word, _ in samples:
        print(trie[word])

    # To try a larger word list, fill the trie from lines of the form
    # "word<TAB>count" (with int(count) as the value) instead of the samples.

    print("-- Serializing...")
    serialized_trie = trie.serialize()
    print("-- Deserializing...")
    trie = Node.deserialize(serialized_trie)
    for word, _ in samples:
        print(trie[word])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment