Skip to content

Instantly share code, notes, and snippets.

@jedie
Created May 3, 2018 10:47
Show Gist options
  • Save jedie/523dc31dbbd22968b58a0f4798876e3c to your computer and use it in GitHub Desktop.
import collections
class SimpleTree(collections.defaultdict):
    """
    Nested mapping that collects lists of values under paths of keys.

    Every node is a ``SimpleTree``; looking up a missing key creates a child
    node.  The values added for a path ``(k1, ..., kn)`` are kept in a list
    that is stored under ``kn`` inside the node reached by that path, e.g.
    ``tree.add(("1",), "x")`` results in ``tree["1"]["1"] == ["x"]``.

    Limitation: because of this layout a path ending in the key ``k`` and a
    longer path that continues with the same key ``k`` need the same slot.
    Using both makes the second ``add()`` fail with a ``TypeError`` or an
    ``AttributeError``.

    Note that iterating a tree yields the tuples described in ``__iter__``,
    not the plain keys a normal dict would yield.

    >>> tree = SimpleTree()
    >>> tree.add(keys=("1",), value="value 1")
    >>> tree.add(keys=("2a", "2b"), value="value 2.1")
    >>> tree.add(keys=("2a", "2b"), value="value 2.2")
    >>> tree.split_add("3a.3b.3c", "value 3.1")
    >>> tree.split_add("3a.3b.3c", "value 3.2")
    >>> tree.split_add("3a.3b.3c", "value 3.3")
    >>> tree.pprint(indent=2, print_depth=True)
    0   <key>  1
    1 <value>    value 1
    0   <key>  2a
    1   <key>    2b
    2 <value>      value 2.1
    2 <value>      value 2.2
    0   <key>  3a
    1   <key>    3b
    2   <key>      3c
    3 <value>        value 3.1
    3 <value>        value 3.2
    3 <value>        value 3.3
    >>> tree.split_add("4a-4b-4c-4d", "value 4.1", split_by="-")
    >>> tree.split_add("4a-4b-4c-4d", "value 4.2", split_by="-")
    >>> tree.split_add("4a-4b-4c-4d", "value 4.3", split_by="-")
    >>> tree.split_add("4a-4b-4c-4d", "value 4.4", split_by="-")
    >>> for keys, values in tree.iter_values():
    ...     print(keys, values)
    ['1'] ['value 1']
    ['2a', '2b'] ['value 2.1', 'value 2.2']
    ['3a', '3b', '3c'] ['value 3.1', 'value 3.2', 'value 3.3']
    ['4a', '4b', '4c', '4d'] ['value 4.1', 'value 4.2', 'value 4.3', 'value 4.4']
    """

    def __init__(self, default_factory=dict, *args, **kwargs):
        """
        Create a tree; called without arguments it is empty.

        The optional arguments mirror ``collections.defaultdict``.  They
        exist because ``defaultdict.copy()``, ``copy.copy()``,
        ``copy.deepcopy()`` and pickling re-invoke the constructor with the
        default factory (and, for shallow copies, the current content).
        Item lookup never uses the factory: ``__missing__`` is overridden.
        """
        super().__init__(default_factory, *args, **kwargs)
        # Remember the concrete class so subclasses create child nodes of
        # their own type.
        self.__type = type(self)

    def __missing__(self, key):
        """Create, store and return a new child node for the unknown *key*."""
        node = self[key] = self.__type()  # keep a local reference ...
        return node  # ... so no second dict lookup is needed

    def add(self, keys, value=None, child=None):
        """
        Append *value* to the list of values stored for the path *keys*.

        :param keys: non-empty sequence of keys, outermost key first
        :param value: object to append to the path's value list
        :param child: node to start from instead of ``self``
        :raises IndexError: if *keys* is an empty sequence
        """
        # Fetch the last key first so an empty path fails before any node is
        # created.
        leaf_key = keys[-1]

        node = self if child is None else child
        # Walk down one level per key; missing nodes are created on access.
        # Done iteratively so long paths neither hit the recursion limit nor
        # get re-sliced at every level.
        for key in keys:
            node = node[key]

        # The value list lives under the last key inside the final node.
        if leaf_key in node:
            node[leaf_key].append(value)
        else:
            node[leaf_key] = [value]

    def split_add(self, keys, value, split_by="."):
        """
        Split the string *keys* at *split_by* and add *value* under that path.
        """
        self.add(keys.split(split_by), value)

    def __iter__(self, keys=None, obj=None, depth=1):
        """
        Walk the tree depth-first, visiting the entries of a node in sorted
        order.

        Yields ``(keys, item, is_key)`` tuples.  For a child node *item* is
        its key, *keys* the list of keys above it and *is_key* is ``True``.
        For stored values *item* is one single value, *keys* the complete
        path of that value and *is_key* is ``False``.

        :param keys: path of *obj*; used by the recursion
        :param obj: node to walk, defaults to ``self``
        :param depth: passed along by the recursion, otherwise unused
        """
        if keys is None:
            keys = []
        if obj is None:
            obj = self

        for key, values in sorted(obj.items()):
            if isinstance(values, self.__type):
                yield (keys, key, True)
                yield from self.__iter__(keys + [key], obj=values, depth=depth + 1)
            else:
                for value in values:
                    yield (keys, value, False)

    def iter_values(self, keys=None, obj=None, depth=1):
        """
        Yield ``(keys, values)`` for every path that has values, visiting the
        entries of a node in sorted order.

        *keys* is the path as a list and *values* is the stored list itself
        (not a copy).

        :param keys: path of *obj*; used by the recursion
        :param obj: node to walk, defaults to ``self``
        :param depth: passed along by the recursion, otherwise unused
        """
        if keys is None:
            keys = []
        if obj is None:
            obj = self

        for key, values in sorted(obj.items()):
            if isinstance(values, self.__type):
                yield from self.iter_values(keys + [key], obj=values, depth=depth + 1)
            else:
                yield (keys, values)

    def pprint(self, indent=4, print_depth=False):
        """
        Print one line per key and per value, indented by nesting depth.

        :param indent: number of spaces added per nesting level
        :param print_depth: prefix every line with its depth number
        """
        for keys, obj, is_key in self:
            depth = len(keys)
            indent_str = " " * (depth * indent)
            obj_type = "<key>" if is_key else "<value>"
            txt = "%7s %s %s" % (obj_type, indent_str, obj)
            if print_depth:
                print(depth, txt)
            else:
                print(txt)
if __name__ == "__main__":
    # Executed as a script: run the doctests embedded in this module and
    # print the (failed, attempted) summary returned by doctest.
    import doctest

    summary = doctest.testmod()
    print(summary)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment