Created
December 14, 2018 16:04
-
-
Save brad-anton/0255e06ba004f800dc710d72d9442f5e to your computer and use it in GitHub Desktop.
Trie Class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
trie.py | |
@brad_anton | |
Go back to college, kid! | |
""" | |
class Node(object):
    """One trie node: a label, a link to its parent and its ordered children."""

    def __init__(self, value, parent=None):
        # Label stored at this node (e.g. a single character or a domain part).
        self.value = value
        # Node this one hangs from; None for a root node.
        self.parent = parent
        # Child Nodes, kept in the order they were appended.
        self.children = []

    def __repr__(self):
        parent_value = None if self.parent is None else self.parent.value
        template = "<Node value='{}', parent='{}', children_count='{}'>"
        return template.format(self.value, parent_value, len(self.children))
class Trie(object):
    """Trie of labels with a compact, Guava-style string serialization.

    Domains added with ``add_domain`` are split by ``get_parts`` into
    character-reversed labels, one label per node.  Words added with
    ``add_word`` are stored one character per node.  ``serialize`` encodes
    the trie as a string, and ``deserialize`` rebuilds a Trie from one.
    """

    # Marker written after the label of a node that has children.  The
    # decoder treats '&' as "interior node that is not itself an entry".
    # Interior labels produced by get_parts always end in '.', so they are
    # never complete domains.  Using ':' (an entry marker) made the decoder
    # report bogus domains such as '.com'.
    INT = '&'
    # Marker written after a leaf label, and once more to close the child
    # list of an interior node.
    LEAF = ','

    # Characters that terminate a label in the encoded string.
    _CONTROL_CHARS = frozenset('&?!:,')
    # Markers whose label completes an entry.  '&' is the only one that
    # does not.
    _ENTRY_MARKERS = frozenset('?!:,')
    # Markers of nodes without children.  Seen right after a child, they
    # close the parent's child list.
    _LEAF_MARKERS = frozenset('?,')

    def __init__(self, verbose=False):
        # Sentinel root; its children are the first label of every entry.
        self.head = Node('head')
        # When True, mutating methods print progress messages.
        self.verbose = verbose

    def add_single(self, single, head):
        """Return the child of ``head`` labelled ``single``, creating it if absent.

        :param single: label to look up (a character or a domain part).
        :param head: parent Node.
        :returns: the existing or newly created child Node.
        """
        existing = self.get_single(single, head)
        if existing is not None:
            if self.verbose:
                print('Found existing node for {}, no need to create a new one'.format(single))
            return existing
        if self.verbose:
            print('Creating Node: {}'.format(single))
        n = Node(single, parent=head)
        head.children.append(n)
        return n

    def add_word(self, word):
        """Insert ``word`` into the trie, one character per node."""
        if self.verbose:
            print('Adding word: {}'.format(word))
        head = self.head
        for char in word:
            head = self.add_single(char, head)

    def get_single(self, value, head):
        """Return the first child of ``head`` whose label equals ``value``, or None."""
        for child in head.children:
            if child.value == value:
                return child
        return None

    def find_word(self, word):
        """Return True if the characters of ``word`` form a path from the root.

        NOTE: nodes carry no end-of-word flag, so this is also True for any
        prefix of an added word (including the empty string).
        """
        head = self.head
        for char in word:
            head = self.get_single(char, head)
            if head is None:
                return False
        return True

    def get_parts(self, domain):
        """Split ``domain`` into the labels stored along its trie path.

        The whole domain is character-reversed and split on '.', and every
        label except the last keeps a trailing '.'.  For example,
        'test.com' becomes ['moc.', 'tset'].

        :param domain: domain name to split.
        :returns: list of labels, root-most first (never empty).
        """
        # Need to do this because Guava includes dots in its serialize
        # output :(  Concatenating the parts and reversing the result
        # gives back the original domain.
        labels = domain[::-1].split('.')
        return [label + '.' for label in labels[:-1]] + [labels[-1]]

    def add_domain(self, domain):
        """Insert ``domain`` using the label path from ``get_parts``."""
        if self.verbose:
            print('Adding domain: {}'.format(domain))
        head = self.head
        for part in self.get_parts(domain):
            head = self.add_single(part, head)

    def add_domains(self, domains):
        """Insert every domain in the iterable ``domains``, in order."""
        for d in domains:
            self.add_domain(d)

    def find_domain(self, domain):
        """Return True if the full label path of ``domain`` exists in the trie.

        ``get_parts`` appends '.' to every label except the last.  So a
        domain that is only a suffix of an added one (e.g. 'com' after
        adding 'test.com') does not match.
        """
        head = self.head
        for part in self.get_parts(domain):
            head = self.get_single(part, head)
            if head is None:
                return False
        return True

    def _encode_into(self, node, out):
        """Append the encoded pieces of ``node`` and its subtree to list ``out``."""
        out.append(node.value)
        if not node.children:
            out.append(self.LEAF)
            return
        out.append(self.INT)
        for child in node.children:
            self._encode_into(child, out)
        # An extra LEAF after the last child closes this node's child list.
        out.append(self.LEAF)

    def recurse(self, branch, node):
        """Return ``branch`` followed by the encoding of ``node``'s subtree.

        A leaf encodes as ``value + LEAF``.  A node with children encodes as
        ``value + INT``, then each child's encoding, then one more ``LEAF``.
        """
        # Collect the pieces in a list and join once.  This avoids re-copying
        # the growing prefix at every node.
        out = [branch]
        self._encode_into(node, out)
        return ''.join(out)

    def serialize(self):
        """Return the whole trie encoded as a single string.

        Every top-level branch gets one additional LEAF.  The decoder reads
        it as an empty label and reports no entry for it.
        """
        return ''.join(self.recurse('', h_child) + self.LEAF
                       for h_child in self.head.children)

    @staticmethod
    def process_serialized(stack, encoded, verbose=False, domains=None):
        """Decode the node (and its subtree) at the start of ``encoded``.

        Every decoded entry is printed as 'Adding domain: <domain>'.  When
        ``domains`` is a list, the entry is also appended to it.

        :param stack: labels of the enclosing nodes.  It is restored to its
            original contents on return.
        :param encoded: encoded trie text; decoding starts at index 0.
        :param verbose: print parser tracing.
        :param domains: optional list that collects decoded domains.
        :returns: number of characters of ``encoded`` consumed.
        """
        return Trie._parse_node(stack, encoded, 0, verbose, domains)

    @staticmethod
    def _parse_node(stack, encoded, start, verbose, domains):
        """Decode the node starting at ``encoded[start]``; return the index just past it."""
        encoded_len = len(encoded)
        if verbose:
            print('encodedLen: {}'.format(encoded_len - start))
            print('Stack: {}'.format(stack))

        # Read the label up to (not including) the next control character.
        # If none follows, the label runs to the end of the input.
        idx = start
        while idx < encoded_len and encoded[idx] not in Trie._CONTROL_CHARS:
            idx += 1
        # '\0' means "input ended without a marker".
        c = encoded[idx] if idx < encoded_len else '\0'
        if verbose and c != '\0':
            print('Got control char: {}'.format(c))

        stack.append(encoded[start:idx])
        if verbose:
            print('Stack (after append): {}'.format(stack))

        if c in Trie._ENTRY_MARKERS:
            domain = ''.join(stack)
            if verbose:
                print('Candidate: {}'.format(domain))
            if domain:
                # The labels were character-reversed on encoding, so reversing
                # the concatenation restores the domain.
                print('Adding domain: {}'.format(domain[::-1]))
                if domains is not None:
                    domains.append(domain[::-1])

        if idx < encoded_len:
            if verbose:
                print('Incrementing idx (1): {}, encodedLen: {}, encoded: {}'.format(
                    idx - start, encoded_len - start, encoded[idx:]))
            # Continue past the control character.
            idx += 1
            if c not in Trie._LEAF_MARKERS:
                # Interior node: decode children until a leaf marker closes
                # the child list, or until the input runs out.
                while idx < encoded_len:
                    idx = Trie._parse_node(stack, encoded, idx, verbose, domains)
                    if idx < encoded_len and encoded[idx] in Trie._LEAF_MARKERS:
                        # End of branch.
                        if verbose:
                            print('Incrementing idx (2): {}, encodedLen: {}, encoded: {}'.format(
                                idx - start, encoded_len - start, encoded[idx:]))
                        idx += 1
                        break

        stack.pop()
        if verbose:
            print('Stack (end): {}'.format(stack))
        return idx

    @staticmethod
    def deserialize(encoded):
        """Decode ``encoded`` (as produced by ``serialize``) into a new Trie.

        Decoded domains are printed while parsing (see ``process_serialized``)
        and then added to the returned Trie in encounter order.
        """
        domains = []
        encoded_len = len(encoded)
        idx = 0
        while idx < encoded_len:
            # Each call consumes at least one character, so the loop terminates.
            idx = Trie._parse_node([], encoded, idx, False, domains)
        trie = Trie()
        trie.add_domains(domains)
        return trie
if __name__ == '__main__':
    # Demo: build a trie from a few domains, print the domains and the
    # serialized form, then decode it again.  The decoder prints each
    # domain it recovers.
    domains = ['test.com', 'test2.com', 'test.org', 'test.test.org']
    t = Trie(verbose=True)
    t.add_domains(domains)
    s = t.serialize()
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print(domains)
    print(s)
    Trie.deserialize(s)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment