Last active
March 12, 2019 01:25
-
-
Save fbparis/e114958a4a9be84146a1a579cf677e8d to your computer and use it in GitHub Desktop.
A very powerful and memory safe data-structure to replace python's dict: Indexed Radix Trie
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A Python3 indexed trie class. | |
An indexed trie's key can be any subscriptable object. | |
Keys of the indexed trie are stored using a "radix trie", a space-optimized data-structure which has many advantages (see https://en.wikipedia.org/wiki/Radix_tree). | |
Also, each key in the indexed trie is associated to a unique index which is built dynamically. | |
Indexed trie is used like a python dictionary (and even a collections.defaultdict if you want to) but its values can also be accessed or updated (but not created) like a list! | |
Example: | |
>>> t = indexedtrie() | |
>>> t["abc"] = "hello" | |
>>> t[0] | |
'hello' | |
>>> t["abc"] | |
'hello' | |
>>> t.index2key(0) | |
'abc' | |
>>> t.key2index("abc") | |
0 | |
>>> t[:] | |
['hello'] | |
>>> print(t) | |
{(0, 'abc'): hello} | |
""" | |
# Module metadata: handle of the author.
__author__ = "@fbparis" | |
# Unique marker object used as the default of optional parameters (value, default_factory),
# compared with "is", so that None remains usable as a real value or default.
_SENTINEL = object() | |
class _Node(object):
    """
    One vertex of the trie.

    Attributes:
        _key: the key fragment carried by this node.
        _parent: the object this node hangs from; it must expose a
            ``_children`` set.
        _index: the index given at construction, ``None`` when omitted.
        _children: set of the nodes hanging below this one.
    """
    __slots__ = "_children", "_parent", "_index", "_key"

    def __init__(self, key, parent, index=None):
        """
        Build the node and register it among the children of ``parent``.
        """
        self._children = set()
        self._key, self._parent, self._index = key, parent, index
        # attaching to the parent is part of construction: a node never exists detached
        parent._children.add(self)
class IndexedtrieKey(object):
    """
    A pair (index, key) acting as an indexedtrie's key.

    Attributes:
        index: integer position of the item.
        key: the key of the item.
    """
    __slots__ = "index", "key"

    def __init__(self, index, key):
        """
        Store the index and the key as given.
        """
        self.index = index
        self.key = key

    def __repr__(self):
        """
        Return "(index, key)" with the key shown through its own repr.
        """
        # %r (not %s) keeps the key unambiguous: the string "1" and the int 1 no
        # longer print the same, and string keys appear quoted, e.g. (0, 'abc').
        return "(%d, %r)" % (self.index, self.key)
class indexedtrie(object):
    """
    The indexed trie data-structure.

    Keys are stored as a radix trie of _Node objects hanging below this object,
    which plays the role of the root.  Every stored key also owns a position in
    two parallel lists: ``_indexes`` (position -> node) and ``_values``
    (position -> value).  A value can therefore be reached either through its
    key or through its integer index.

    A node whose ``_index`` is None is a purely structural node: it exists only
    because several keys share its prefix, and it holds no value.
    """
    __slots__ = "_children", "_indexes", "_values", "_nodescount", "_default_factory"

    def __init__(self, items=None, default_factory=_SENTINEL):
        """
        Create the trie, optionally filling it from an iterable of (key, value) pairs.

        A key given as an IndexedtrieKey is unwrapped to its plain key, which
        lets copy() feed items() straight back into the constructor.
        """
        self._children = set()
        self.setdefault(default_factory)
        self._indexes = []
        self._values = []
        self._nodescount = 0  # keeping track of nodes count is purely informational
        if items is not None:
            for k, v in items:
                if isinstance(k, IndexedtrieKey):
                    k = k.key
                self[k] = v

    @classmethod
    def fromkeys(cls, keys, value=_SENTINEL, default_factory=_SENTINEL):
        """
        Build a new indexedtrie from an iterable of keys.

        Each key receives ``value`` when it is given, otherwise a fresh default
        produced by ``default_factory`` when it is given, otherwise None.
        """
        obj = cls(default_factory=default_factory)
        for key in keys:
            if value is not _SENTINEL:
                obj[key] = value
            elif default_factory is not _SENTINEL:
                obj[key] = obj._default_factory()
            else:
                obj[key] = None
        return obj

    @classmethod
    def fromsplit(cls, keys, value=_SENTINEL, default_factory=_SENTINEL):
        """
        Build a new indexedtrie from a splittable object (anything with .split()).

        Equivalent to fromkeys(keys.split(), ...).
        """
        return cls.fromkeys(keys.split(), value=value, default_factory=default_factory)

    def setdefault(self, factory=_SENTINEL):
        """
        Set the default used for lookups of missing keys, or clear it when
        called without argument.

        ``factory`` may be a zero-argument callable or a plain default value.
        """
        if factory is _SENTINEL:
            self._default_factory = _SENTINEL
            return
        # indexed trie will act like a collections.defaultdict except in some cases because the
        # __missing__ method is not implemented here (on purpose).
        # That means that simple lookups on a non existing key will return a default value without
        # adding the key.
        # Also means that if your default_factory is for example "list", you won't be able to create
        # new items with "append" or "extend" methods which are updating the list itself.
        # Instead you have to do something like trie["newkey"] += [...]
        try:
            factory()
        except TypeError:
            # a default value is also accepted as default_factory, even "None"
            self._default_factory = lambda: factory
        else:
            self._default_factory = factory

    def copy(self):
        """
        Return a pseudo-shallow copy of the indexedtrie.

        Keys and nodes are rebuilt, but values are copied by reference.
        """
        return self.__class__(self.items(), default_factory=self._default_factory)

    def __len__(self):
        """
        Return the number of stored items.
        """
        return len(self._indexes)

    def __repr__(self):
        """
        Return a short summary: class name, address, item count, node count and default.
        """
        if self._default_factory is not _SENTINEL:
            # wrap in a 1-tuple: a bare tuple default would be unpacked by "%" and
            # raise "not all arguments converted"
            default = ", default_value=%s" % (self._default_factory(),)
        else:
            default = ""
        return "<%s object at %s: %d items, %d nodes%s>" % (
            self.__class__.__name__, hex(id(self)), len(self), self._nodescount, default)

    def __str__(self):
        """
        Return a dict-like rendering "{key: value, ...}" in index order.
        """
        return "{%s}" % ", ".join("%s: %s" % (k, v) for k, v in self.items())

    def __iter__(self):
        """
        Iterate over the IndexedtrieKey objects of the trie, in index order.
        """
        return self.keys()

    def __contains__(self, key_or_index):
        """
        Return True if the key or index exists in the indexed trie.

        An int (or the index of an IndexedtrieKey) is tested against the range
        0 <= index < len(self).  Raise TypeError for anything that is neither
        an index nor a subscriptable key.
        """
        if isinstance(key_or_index, IndexedtrieKey):
            return 0 <= key_or_index.index < len(self)
        if isinstance(key_or_index, int):
            return 0 <= key_or_index < len(self)
        if not self._seems_valid_key(key_or_index):
            raise TypeError("invalid key type")
        try:
            node = self._get_node(key_or_index)
        except KeyError:
            return False
        # structural nodes exist in the trie but do not hold an item
        return node._index is not None

    def __getitem__(self, key_or_index):
        """
        Return the value stored under a key, an index, an IndexedtrieKey or a slice.

        Indexes and slices are applied to the list of values.  A missing key
        yields the default when one is set (without creating the key), else
        raises KeyError.  Raise TypeError for an unusable key type.
        """
        if isinstance(key_or_index, IndexedtrieKey):
            return self._values[key_or_index.index]
        if isinstance(key_or_index, (int, slice)):
            return self._values[key_or_index]
        if not self._seems_valid_key(key_or_index):
            raise TypeError("invalid key type")
        try:
            node = self._get_node(key_or_index)
        except KeyError:
            node = None
        if node is not None and node._index is not None:
            return self._values[node._index]
        if self._default_factory is _SENTINEL:
            raise KeyError(key_or_index)
        return self._default_factory()

    def __setitem__(self, key_or_index, value):
        """
        Store ``value`` under a key, or overwrite the value at an existing index.

        Indexes (int or IndexedtrieKey) can only update existing items; new
        items are created through keys.  Slices are not supported.
        """
        if isinstance(key_or_index, IndexedtrieKey):
            self._values[key_or_index.index] = value
        elif isinstance(key_or_index, int):
            self._values[key_or_index] = value
        elif isinstance(key_or_index, slice):
            raise NotImplementedError
        elif self._seems_valid_key(key_or_index):
            try:
                node = self._get_node(key_or_index)
            except KeyError:
                # create a new node
                self._add_node(key_or_index, value)
            else:
                if node._index is None:
                    # node exists but is purely structural: index it and record the value
                    self._add_to_index(node, value)
                else:
                    self._values[node._index] = value
        else:
            raise TypeError("invalid key type")

    def __delitem__(self, key_or_index):
        """
        Remove the item designated by a key, an index or an IndexedtrieKey.

        The last item takes the index of the removed one, so indexes stay
        contiguous but the index of that last item changes.  Raise KeyError for
        a missing key, IndexError for a bad index, TypeError for an unusable
        key type.  Slices are not supported.
        """
        if isinstance(key_or_index, IndexedtrieKey):
            node = self._indexes[key_or_index.index]
        elif isinstance(key_or_index, int):
            node = self._indexes[key_or_index]
        elif isinstance(key_or_index, slice):
            raise NotImplementedError
        elif self._seems_valid_key(key_or_index):
            node = self._get_node(key_or_index)
            if node._index is None:
                raise KeyError(key_or_index)
        else:
            raise TypeError("invalid key type")
        # switch last index with deleted index (except if deleted index is last index)
        last_node, last_value = self._indexes.pop(), self._values.pop()
        if node._index != last_node._index:
            last_node._index = node._index
            self._indexes[node._index] = last_node
            self._values[node._index] = last_value
        if len(node._children) > 1:
            # case 1: node has more than 1 child, it must stay as a structural node
            node._index = None
        elif len(node._children) == 1:
            # case 2: node has 1 child, which absorbs the node's key and takes its place
            self._replace_by_only_child(node)
        else:
            # case 3: node has no child, drop it and check the parent node
            parent = node._parent
            parent._children.remove(node)
            self._nodescount -= 1
            # a structural parent left with a single child is no longer needed;
            # the root (self) is never merged
            if parent is not self and parent._index is None and len(parent._children) == 1:
                self._replace_by_only_child(parent)

    def _replace_by_only_child(self, node):
        """
        Remove ``node`` from the trie, putting its single child in its place.

        The child's key is prefixed with the node's key so that full keys are unchanged.
        """
        child = node._children.pop()
        child._key = node._key + child._key
        child._parent = node._parent
        node._parent._children.add(child)
        node._parent._children.remove(node)
        self._nodescount -= 1

    @staticmethod
    def _seems_valid_key(key):
        """
        Return True if "key" can be a valid key (must be subscriptable with a slice).
        """
        try:
            key[:0]
        except (TypeError, KeyError):
            # KeyError: since Python 3.12 slices are hashable, so a dict answers
            # d[:0] with KeyError instead of TypeError
            return False
        return True

    def keys(self, prefix=None):
        """
        Yield keys stored in the indexedtrie as IndexedtrieKey objects.

        Without prefix, keys come in index order.  With a prefix, only keys
        starting with it are yielded, level by level from the root.  Raise
        ValueError if the prefix is not subscriptable; as this is a generator,
        the error surfaces on first iteration.
        """
        if prefix is None:
            for i, node in enumerate(self._indexes):
                yield IndexedtrieKey(i, self._get_key(node))
            return
        if not self._seems_valid_key(prefix):
            raise ValueError("invalid prefix type")
        empty = prefix[:0]
        # each entry: (key built so far, part of the prefix still to match, node to examine)
        pending = [(empty, prefix, child) for child in self._children]
        while pending:
            following = []
            for path, rest, child in pending:
                if rest == child._key[:len(rest)]:
                    # the remaining prefix is fully matched: everything below qualifies
                    full = path + child._key
                    following.extend([(full, empty, sub) for sub in child._children])
                    if child._index is not None:
                        yield IndexedtrieKey(child._index, full)
                elif rest[:len(child._key)] == child._key:
                    # the node consumes only part of the prefix: keep matching below it
                    shorter = rest[len(child._key):]
                    partial = path + rest[:len(child._key)]
                    following.extend([(partial, shorter, sub) for sub in child._children])
            pending = following

    def values(self, prefix=None):
        """
        Yield values stored in the indexedtrie.

        If prefix is given, yield only values of items with key matching the prefix.
        """
        if prefix is None:
            for value in self._values:
                yield value
        else:
            for key in self.keys(prefix):
                yield self._values[key.index]

    def items(self, prefix=None):
        """
        Yield (key, value) pairs stored in the indexedtrie where key is an IndexedtrieKey.

        If prefix is given, yield only pairs of items with key matching the prefix.
        """
        for key in self.keys(prefix):
            yield key, self._values[key.index]

    def show_tree(self, node=None, level=0):
        """
        Pretty print the internal trie (recursive function).

        ``node`` is the subtree to print (the whole trie by default) and
        ``level`` the current depth, rendered as leading dashes.
        """
        if node is None:
            node = self
        for child in node._children:
            print("-" * level + "<key=%s, index=%s>" % (child._key, child._index))
            if len(child._children):
                self.show_tree(child, level + 1)

    def _get_node(self, key):
        """
        Return the node whose full path equals key, or raise KeyError(key).

        The node returned may be a structural one (``_index`` is None).
        """
        wanted = key  # "key" is consumed while descending; keep the original for the error
        children = self._children
        while children:
            for child in children:
                if key == child._key:
                    return child
                # An empty-key node is a prefix of every key: descending into it
                # would hide its siblings, so it can only be matched exactly.
                if len(child._key) and child._key == key[:len(child._key)]:
                    children = child._children
                    key = key[len(child._key):]
                    break
            else:
                # no child matches at this level
                break
        raise KeyError(wanted)

    def _add_node(self, key, value):
        """
        Add a new key in the trie and update indexes and values.

        The caller must have checked that no node exists for this key
        (see __setitem__).
        """
        children = self._children
        parent = self
        moved = None
        # An empty key shares a (zero-length) prefix with everything; it is stored
        # as a plain leaf under the root and never takes part in prefix matching.
        done = len(children) == 0 or len(key) == 0
        # we want to insert key="abc"
        while not done:
            done = True
            for child in children:
                if len(child._key) == 0:
                    continue
                if child._key == key[:len(child._key)]:
                    # case 1: child's key is "ab", insert "c" in child's children
                    parent = child
                    children = child._children
                    key = key[len(child._key):]
                    done = len(children) == 0
                    break
                elif key == child._key[:len(key)]:
                    # case 2: child's key is "abcd", we insert "abc" in place of the child
                    # child's parent will be the inserted node and child's key is now "d"
                    parent = child._parent
                    moved = child
                    parent._children.remove(moved)
                    moved._key = moved._key[len(key):]
                    break
                elif type(key) is type(child._key):  # keys of different types never share a prefix
                    # find longest common prefix (zip stops at the shorter of the two)
                    prefix = key[:0]
                    for i, (mine, theirs) in enumerate(zip(key, child._key)):
                        if mine != theirs:
                            prefix = key[:i]
                            break
                    if prefix:
                        # case 3: child's key is "abd", we spawn a new node with key "ab"
                        # to replace child; child's key is now "d" and child's parent is
                        # the new node, under which "c" will be inserted as well
                        node = _Node(prefix, child._parent)
                        self._nodescount += 1
                        child._parent._children.remove(child)
                        child._key = child._key[len(prefix):]
                        child._parent = node
                        node._children.add(child)
                        key = key[len(prefix):]
                        parent = node
                        break
        # create the new node
        node = _Node(key, parent)
        self._nodescount += 1
        if moved is not None:
            # if we have moved an existing node, hang it below the new one
            moved._parent = node
            node._children.add(moved)
        self._add_to_index(node, value)

    def _get_key(self, node):
        """
        Rebuild the full key of a node by concatenating fragments up to the root.
        """
        key = node._key
        while node._parent is not self:
            node = node._parent
            key = node._key + key
        return key

    def _add_to_index(self, node, value):
        """
        Give ``node`` the next free index and record its value.
        """
        node._index = len(self)
        self._indexes.append(node)
        self._values.append(value)

    def key2index(self, key):
        """
        key -> index

        Raise KeyError if the key holds no item, TypeError for an unusable key type.
        """
        if not self._seems_valid_key(key):
            raise TypeError("invalid key type")
        node = self._get_node(key)
        if node._index is None:
            raise KeyError(key)
        return node._index

    def index2key(self, index):
        """
        index or IndexedtrieKey -> key.

        Raise TypeError if index is neither an int nor an IndexedtrieKey,
        IndexError if it is outside 0 <= index < len(self).
        """
        if isinstance(index, IndexedtrieKey):
            index = index.index
        elif not isinstance(index, int):
            raise TypeError("index must be an int")
        # ">=": len(self._indexes) itself is already out of range
        if index < 0 or index >= len(self._indexes):
            raise IndexError
        return self._get_key(self._indexes[index])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment