Skip to content

Instantly share code, notes, and snippets.

Last active June 2, 2020 00:18
Show Gist options
  • Save fbparis/b3ddd5673b603b42c880974b23db7cda to your computer and use it in GitHub Desktop.
Save fbparis/b3ddd5673b603b42c880974b23db7cda to your computer and use it in GitHub Desktop.
A replacement for the indexed trie data-structure with a very low memory footprint!
kik (key to index to key) is a new implementation of an indexed trie, reducing the memory footprint by a factor of 7 or more!
The main differences with indexed tries are:
* Trie nodes are no longer stored as Node objects but in a Python list. An internal mechanism uses the list to simulate tree behaviour, without any recursive function.
* Nodes and values stored in the kik can be compressed with several compression levels:
* None: no compression at all
* 0: data are stored as pickle.dumps (memory use divided by 3-4)
* -1, 1 to 9: the pickle.dumps() output is additionally compressed with zlib at that level
* Some internal methods can be cached with a custom implementation of an LRU cache.
import pickle
import zlib
import logging
from collections import OrderedDict, namedtuple
from functools import wraps
from smart_open import smart_open
__author__ = "@fbparis"

# Log record format (apparently the remnant of a logging.basicConfig() call).
# Kept as a constant instead of a module-level ``format`` name, which shadowed
# the builtin format(); configuring logging is left to applications.
LOG_FORMAT = '%(asctime)s %(name)s %(levelname)s %(message)s'

logger = logging.getLogger(__name__)

# Marker for "argument not given" where None is a meaningful value.
_SENTINEL = object()

# Statistics returned by LRU.cache_info(), mirroring functools' CacheInfo.
CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize")
def lru_cache(cache):
    """
    A replacement for functools.lru_cache() built on a custom LRU class.

    It can cache (bound) methods: results are stored in ``cache`` under the
    tuple of positional arguments, so decorating a bound method keeps ``self``
    out of the key. The wrapper exposes the original callable as
    ``__wrapped__`` (via functools.wraps).

    :param cache: mapping used as storage, typically an ``LRU`` instance; a
        lookup of a missing key must raise KeyError.
    :return: a decorator.
    """
    def decorator(func):
        logger.debug("assigning cache %r to function %s", cache, func.__name__)

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            if kwargs:
                # keyword arguments are not part of the cache key: bypass the
                # cache rather than return a result computed for other values
                return func(*args, **kwargs)
            try:
                ret = cache[args]
            except KeyError:
                ret = func(*args)
                logger.debug("cache updated for function %s", func.__name__)
                cache[args] = ret
            else:
                logger.debug("cached value returned for function %s", func.__name__)
            return ret

        return wrapped_func

    return decorator
class LRU(OrderedDict):
    """
    Custom implementation of a LRU cache, built on top of an OrderedDict.

    Entries are kept from least to most recently used: a successful lookup or
    a store moves the entry to the end, and the oldest entries are evicted as
    soon as ``len(self)`` exceeds ``maxsize``.

    Keys are the positional-argument tuples built by ``lru_cache``. Deletion
    works by *argument*: ``del cache[x]`` removes the entry stored under
    ``(x,)`` (falling back to the key ``x`` itself), and deleting a missing
    entry is silently ignored, so callers can invalidate a single-argument
    call without checking first.

    ``LRU(None)`` returns ``None`` instead of a cache ("no caching").
    """
    __slots__ = "_hits", "_misses", "_maxsize"

    def __new__(cls, maxsize=128, *args, **kwargs):
        # a None size means "no cache at all"; __init__ is then not called
        if maxsize is None:
            return None
        return super().__new__(cls)

    def __init__(self, maxsize=128, *args, **kwargs):
        self.maxsize = maxsize  # validated by the property setter
        self._hits = 0
        self._misses = 0
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        try:
            value = super().__getitem__(key)
        except KeyError:
            self._misses += 1
            raise
        # a hit makes this entry the most recently used one
        self.move_to_end(key)
        self._hits += 1
        return value

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.move_to_end(key)
        self._evict()

    def __delitem__(self, key):
        """Forget the entry for call argument ``key`` (or raw key ``key``); ignore if absent."""
        for candidate in ((key,), key):
            if super().__contains__(candidate):
                super().__delitem__(candidate)
                return

    def __repr__(self):
        return "<%s object at %s: %s>" % (self.__class__.__name__, hex(id(self)), self.cache_info())

    def _evict(self):
        """Drop least recently used entries until the size fits maxsize."""
        while len(self) > self._maxsize:
            # delete the raw oldest key: our __delitem__ expects a call argument,
            # and unpacking the key would fail for multi-argument calls
            super().__delitem__(next(iter(self)))

    def cache_info(self):
        """Return a CacheInfo(hits, misses, maxsize, currsize) snapshot."""
        return CacheInfo(self._hits, self._misses, self._maxsize, len(self))

    def clear(self):
        """Remove every entry and reset the hit/miss counters."""
        super().clear()
        self._hits, self._misses = 0, 0

    @property
    def maxsize(self):
        """Maximum number of entries kept."""
        return self._maxsize

    @maxsize.setter
    def maxsize(self, maxsize):
        if not isinstance(maxsize, int):
            raise TypeError
        elif maxsize < 2:
            raise ValueError
        elif maxsize & (maxsize - 1) != 0:
            logger.warning("LRU feature performs best when maxsize is a power-of-two, maybe.")
        self._maxsize = maxsize
        # shrinking evicts the oldest entries right away
        self._evict()
class Node(object):
    """
    A single node in the trie.

    :ivar parent: index of the parent node in the kik's node list, or None for
        a child of the root.
    :ivar index: position of the key ending at this node, or None when the
        node only branches and carries no key.
    :ivar children: set of child node indexes (any iterable is copied into a set).
    :ivar key: the key fragment held by this node.
    """
    __slots__ = "parent", "index", "children", "key"

    def __init__(self, parent, index, children, key):
        self.parent = parent
        self.index = index
        self.children = set(children)
        self.key = key

    def __eq__(self, other):
        # foreign objects are not comparable: let Python fall back instead of
        # raising AttributeError on the missing slots
        if not isinstance(other, Node):
            return NotImplemented
        return all(getattr(self, name) == getattr(other, name) for name in self.__slots__)

    def __repr__(self):
        return "<%s object at %s: parent=%s, index=%s, children=%s, key=%s>" % (self.__class__.__name__, hex(id(self)), self.parent, self.index, self.children, self.key)
class kik_key(object):
    """
    A pair (index, key) representing a key in the kik object.

    ``index`` is the key's position in the kik and ``key`` the full key.
    """
    __slots__ = "index", "key"

    def __init__(self, index, key):
        self.index = index
        self.key = key

    def __repr__(self):
        return "(%d, %s)" % (self.index, self.key)

    def __eq__(self, other):
        # foreign objects are not comparable: avoid AttributeError on missing slots
        if not isinstance(other, kik_key):
            return NotImplemented
        return self.index == other.index and self.key == other.key
class kik(object):
    """
    key-to-index-to-key object.

    key is any subscriptable (anything supporting ``key[:0]``, such as ``str``,
    ``bytes`` or ``tuple``) and the matching index is an autoincrement integer.
    values can be any objects.

    Storage layout:

    * ``_nodes``: flat list of serialized nodes (see ``_format_node``) forming
      a radix trie; ``_root`` holds the indexes of the nodes without parent.
    * ``_indexes[i]``: index in ``_nodes`` of the node holding the i-th key.
    * ``_values[i]``: value of the i-th key, serialized per ``compress_values``.

    Deleting a key moves the last key/value into the freed slot, so indexes
    are always ``0 .. len(self) - 1`` but are not stable across deletions.

    Config options (see ``default_config``): ``compress_nodes`` /
    ``compress_values`` take None (raw), 0 (pickle) or a zlib level (-1, 1..9);
    ``cache_nodes`` / ``cache_indexes`` / ``cache_values`` take an LRU size or
    None to disable that cache; ``default_factory`` behaves like
    ``set_default_factory()``.
    """

    # "__dict__" is kept so that _set_cache() can shadow the _cache_* methods
    # with per-instance cached wrappers.
    __slots__ = "_config", "_indexes", "_values", "_nodes", "_cache", "_root", "__dict__"

    # Class-wide defaults; every instance works on its own copy (see __init__).
    default_config = {
        "compress_nodes": 0,
        "compress_values": None,
        "cache_keys": None,  # deprecated, dropped by _setup()
        "cache_nodes": 128,
        "cache_indexes": 128,
        "cache_values": None,
    }

    def __init__(self, *args, **kwargs):
        """
        Create a kik, optionally filled from a single source of items.

        :param args: at most one of: another kik (its config is copied too), a
            mapping, or an iterable of ``(key, value)`` pairs (``kik_key`` keys
            are reduced to their ``key``).
        :param kwargs: config overrides (keys of ``default_config``) and/or
            ``default_factory``.
        :raises ValueError: if more than one positional argument is given or
            a compress level is invalid.
        :raises TypeError: if an unknown config option is given.
        """
        if len(args) > 1:
            raise ValueError
        # Work on a copy: sharing the class dict would let one instance's
        # settings (and _setup()'s removal of "cache_keys") leak into all others.
        config = dict(self.default_config)
        if args and isinstance(args[0], self.__class__):
            config.update(args[0]._config)
        factory = kwargs.pop("default_factory", config.pop("default_factory", _SENTINEL))
        unknown = set(kwargs) - set(self.default_config)
        if unknown:
            raise TypeError("unknown kik option(s): %s" % ", ".join(sorted(unknown)))
        config.update(kwargs)
        self._check_compress_level(config["compress_nodes"])
        self._check_compress_level(config["compress_values"])
        self._config = config
        self._indexes = []
        self._values = []
        self._nodes = []
        self._root = set()
        self._setup()
        self.set_default_factory(factory)
        if args:
            try:
                items = args[0].items()
            except AttributeError:
                items = args[0]
            for k, v in items:
                self.__setitem__(k.key if isinstance(k, kik_key) else k, v)

    def __getstate__(self):
        """Prevent the caches from being pickled."""
        # NOTE(review): a non-callable default_factory is stored as a lambda
        # (see set_default_factory), which pickle cannot serialize.
        return self._config, self._indexes, self._values, self._nodes, self._root

    def __setstate__(self, state):
        """Reset the caches on unpickling."""
        self._config, self._indexes, self._values, self._nodes, self._root = state
        self._setup()

    def __eq__(self, other):
        # compares the raw storage: same keys inserted in another order differ
        if not isinstance(other, kik):
            return NotImplemented
        return self._nodes == other._nodes and self._values == other._values

    def __len__(self):
        return len(self._indexes)

    def __repr__(self):
        return "<%s object at %s: len=%d, nodes=%d, config=%s>" % (self.__class__.__name__, hex(id(self)), len(self._indexes), len(self._nodes), self._config)

    def __str__(self):
        ret = ["%s: %s" % (k, v) for k, v in self.items()]
        return "{%s}" % ", ".join(ret)

    def __iter__(self):
        return self.keys()

    def __contains__(self, key_or_index):
        """Return True if key or index exists in the kik."""
        if isinstance(key_or_index, kik_key):
            return 0 <= key_or_index.index < len(self)
        if isinstance(key_or_index, int):
            return 0 <= key_or_index < len(self)
        if self._seems_valid_key(key_or_index):
            try:
                node, _ = self._cache_keys(key_or_index)
            except KeyError:
                return False
            # a node without index only branches: its key is not stored
            return node.index is not None
        return False

    def __getitem__(self, key_or_index):
        """
        Return the value for an index, a kik_key, a key, or a list for a slice.

        A missing key returns ``default_factory()`` when one is set (without
        storing it), otherwise raises KeyError.
        """
        if isinstance(key_or_index, kik_key):
            return self._cache_values(self._normalize_index(key_or_index.index))
        if isinstance(key_or_index, int):
            return self._cache_values(self._normalize_index(key_or_index))
        if isinstance(key_or_index, slice):
            return [self._cache_values(i) for i in range(*key_or_index.indices(len(self)))]
        if self._seems_valid_key(key_or_index):
            try:
                node, _ = self._cache_keys(key_or_index)
            except KeyError:
                if self.default_factory is _SENTINEL:
                    raise
                return self.default_factory()
            if node.index is None:
                if self.default_factory is _SENTINEL:
                    raise KeyError(key_or_index)
                return self.default_factory()
            return self._cache_values(node.index)
        raise TypeError("unsupported key type: %s" % type(key_or_index).__name__)

    def __setitem__(self, key_or_index, value):
        """Set the value of an existing index / kik_key, or set or add a key."""
        if isinstance(key_or_index, kik_key):
            self._update_values(value, key_or_index.index)
        elif isinstance(key_or_index, int):
            self._update_values(value, key_or_index)
        elif isinstance(key_or_index, slice):
            raise NotImplementedError
        elif self._seems_valid_key(key_or_index):
            try:
                node, node_index = self._cache_keys(key_or_index)
            except KeyError:
                # create a new node
                self._add_node(key_or_index, value)
            else:
                if node.index is None:
                    # existing branching node: give it an index and a value
                    self._activate_node(node, node_index, value)
                else:
                    # update _values
                    self._update_values(value, node.index)
        else:
            raise TypeError("unsupported key type: %s" % type(key_or_index).__name__)

    def __delitem__(self, key_or_index):
        """Delete by index, kik_key or key; the last item takes the freed index."""
        if isinstance(key_or_index, (kik_key, int)):
            index = key_or_index.index if isinstance(key_or_index, kik_key) else key_or_index
            node_index = self._indexes[index]
            node = self._cache_nodes(node_index)
        elif isinstance(key_or_index, slice):
            raise NotImplementedError
        elif self._seems_valid_key(key_or_index):
            node, node_index = self._cache_keys(key_or_index)
            if node.index is None:
                raise KeyError(key_or_index)
        else:
            raise TypeError("unsupported key type: %s" % type(key_or_index).__name__)
        # clear cache entries
        if "cache_values" in self._cache:
            del self._cache["cache_values"][node.index]
        # switch last indexes and values items with the deleted one
        last_index, last_value = self._indexes.pop(), self._values.pop()
        if node.index != len(self):
            # the former last key now lives at node.index: update its node
            last_node = self._cache_nodes(last_index)
            if "cache_indexes" in self._cache:
                del self._cache["cache_indexes"][last_index]
            if "cache_values" in self._cache:
                del self._cache["cache_values"][last_node.index]
            last_node.index = node.index
            self._indexes[node.index] = last_index
            self._values[node.index] = last_value
            self._update_nodes(last_node, last_index)
        # update the internal radix trie
        assert node.index is not None, "node.index for index %d should not be None: %r" % (node_index, node)
        if len(node.children) > 1:
            # case 1: node has more than 1 child, only turn index off
            node.index = None
            self._update_nodes(node, node_index)
        else:
            # cases 2 (leaf) and 3 (single child) are passed to _remove_node
            self._remove_node(node, node_index)

    @classmethod
    def fromkeys(cls, keys, value=_SENTINEL, **kwargs):
        """
        Build a kik from an iterable of keys.

        :param value: value for every key; when omitted, ``default_factory()``
            is used if a default factory is configured, otherwise None.
        :param kwargs: config options forwarded to the constructor.
        """
        obj = cls(**kwargs)
        for key in keys:
            if value is not _SENTINEL:
                obj[key] = value
            elif obj.default_factory is not _SENTINEL:
                obj[key] = obj.default_factory()
            else:
                obj[key] = None
        return obj

    @classmethod
    def fromsplit(cls, keys, value=_SENTINEL, **kwargs):
        """Like fromkeys() with the whitespace-separated words of the string ``keys``."""
        return cls.fromkeys(keys.split(), value, **kwargs)

    @classmethod
    def load(cls, name):
        """
        Load an object saved by save() from ``"<name>.<classname>.bz2"``.

        Security: the file is unpickled, so only load trusted files.
        """
        with smart_open("%s.%s.bz2" % (name, cls.__name__), "rb") as f:
            obj = pickle.load(f)
        assert isinstance(obj, cls)
        return obj

    def save(self, name):
        """Pickle self to ``"<name>.<classname>.bz2"`` (caches are not saved)."""
        with smart_open("%s.%s.bz2" % (name, self.__class__.__name__), "wb") as f:
            pickle.dump(self, f)

    def copy(self):
        """Return a shallow copy of the kik (same config, items re-inserted)."""
        return self.__class__(self)

    def keys(self, prefix=None):
        """
        Yield kik_key(index, key) items, all of them or those starting with prefix.

        Without prefix, keys come in index order; with a prefix, in
        breadth-first trie order.

        :raises ValueError: (on iteration) if prefix is not subscriptable.
        """
        if prefix is None:
            for i, node_index in enumerate(self._indexes):
                yield kik_key(i, self._nocache_indexes(node_index))
            return
        if not self._seems_valid_key(prefix):
            raise ValueError
        empty = prefix[:0]
        # breadth-first walk over (key built so far, unmatched prefix rest, node index)
        pending = [(empty, prefix, child_index) for child_index in self._root]
        while pending:
            next_pending = []
            for key, rest, child_index in pending:
                child = self._cache_nodes(child_index)
                if rest == child.key[:len(rest)]:
                    # prefix fully matched: this node and its subtree are results
                    full_key = key + child.key
                    next_pending.extend((full_key, empty, i) for i in child.children)
                    if child.index is not None:
                        yield kik_key(child.index, full_key)
                elif rest[:len(child.key)] == child.key:
                    # child.key consumes the start of the prefix: keep descending
                    next_pending.extend((key + child.key, rest[len(child.key):], i) for i in child.children)
            pending = next_pending

    def values(self, prefix=None):
        """Yield values, all of them or those whose key starts with prefix."""
        if prefix is None:
            for i in range(len(self._values)):
                yield self._nocache_values(i)
        else:
            for key in self.keys(prefix):
                yield self._nocache_values(key.index)

    def items(self, prefix=None):
        """Yield (kik_key, value) pairs, optionally restricted to a key prefix."""
        for key in self.keys(prefix):
            yield key, self._nocache_values(key.index)

    def show_tree(self, children=None, level=0):
        """Pretty print the internal trie (recursive function)."""
        if children is None:
            children = self._root
        for child_index in children:
            child = self._nocache_nodes(child_index)
            print("-" * level + "<key=%s, index=%s>" % (child.key, child.index))
            self.show_tree(child.children, level + 1)

    def cache_info(self):
        """Print the statistics of every active cache."""
        if self._cache is not None:
            for name, cache in self._cache.items():
                print("%s: %r" % (name, cache))

    @property
    def default_factory(self):
        """The configured default factory, or _SENTINEL when none is set."""
        try:
            return self._config["default_factory"]
        except KeyError:
            return _SENTINEL

    def set_default_factory(self, factory=_SENTINEL):
        """
        Set the default factory, or remove it when called without argument.

        A callable is stored as-is; any other object (even None) is used as a
        constant default value. ``callable()`` is used instead of trying
        ``factory()``, which ran the factory and misclassified callables
        raising TypeError.
        """
        # kik will act like a collections.defaultdict except in some cases because the __missing__
        # method is not implemented here (on purpose).
        # That means that simple lookups on a non existing key will return a default value without adding
        # the key, which is the more logical way to do.
        # Also means that if your default_factory is for example "list", you won't be able to create new
        # items with "append" or "extend" methods which are updating the list itself.
        # Instead you have to do something like kik_object["newkey"] += [...]
        if factory is _SENTINEL:
            self._config.pop("default_factory", None)
        elif callable(factory):
            self._config["default_factory"] = factory
        else:
            self._config["default_factory"] = lambda: factory

    @property
    def cache(self):
        """Mapping of cache name ("cache_nodes", ...) to its LRU instance."""
        return self._cache

    def get_cache(self, key):
        """Return the LRU for "nodes", "indexes" or "values"; ValueError if inactive."""
        try:
            return self._cache["cache_" + key]
        except KeyError:
            raise ValueError(key) from None

    def set_cache(self, key, value):
        """Create, resize (int size) or remove (None) the cache named key."""
        return self._set_cache("cache_" + key, value)

    def remove_cache(self, key):
        self.set_cache(key, None)

    def compress_level(self, key):
        """Return the compress level of "nodes" or "values"; ValueError otherwise."""
        try:
            return self._config["compress_" + key]
        except KeyError:
            raise ValueError(key) from None

    def set_compress_level(self, key, value):
        """
        Change the compression of "nodes" or "values" and re-encode the data.

        :param value: None (raw), 0 (pickle only) or a zlib level (-1, 1..9).
        :raises ValueError: on an unknown key or an invalid level; nothing is
            modified in that case.
        """
        current = self.compress_level(key)
        self._check_compress_level(value)
        if current == value:
            return
        storage = getattr(self, "_" + key)
        # build the re-encoded list first so a failure leaves the data untouched
        converted = [self._encode(self._decode(x, current), value) for x in storage]
        storage[:] = converted
        self._config["compress_" + key] = value

    def key_to_index(self, key):
        """
        Return the index of key.

        :raises KeyError: if key is not stored in the kik.
        """
        node, _ = self._cache_keys(key)
        if node.index is None:
            raise KeyError(key)
        return node.index

    def index_to_key(self, index):
        """Return the key stored at index."""
        return self._cache_indexes(self._indexes[index])

    def key_from_index(self, index):
        return self.index_to_key(index)

    def index_from_key(self, key):
        return self.key_to_index(key)

    def check_integrity(self):
        """
        Check the trie, the index tables and the caches against each other.

        Debugging helper built on ``assert`` statements (skipped under -O).

        :return: True; an AssertionError is raised on the first violation.
        """
        logger.info("checking integrity...")
        # assert len([1 for k, v in self.items() if k.key != v]) == 0, "key do not match value"
        assert all(self._nocache_nodes(self._indexes[k.index]).index == k.index for k in self.keys()), "index do not match node.index"
        for node_index in self._root:
            assert node_index < len(self._nodes), "root's child index out of range: %d" % node_index
            node = self._nocache_nodes(node_index)
            assert node.parent is None, "root's child has bad parent: %r" % node
        for node_index in range(len(self._nodes)):
            node = self._nocache_nodes(node_index)
            if node.index is not None:
                assert node.index < len(self), "node.index out of range: %r" % node
                assert self._indexes[node.index] == node_index, "index do not match node_index: %r" % node
            if node.parent is None:
                assert node_index in self._root, "node is not in root: %r" % node
            else:
                assert node.parent < len(self._nodes), "node.parent out of range: %r" % node
                parent = self._nocache_nodes(node.parent)
                assert node_index in parent.children, "node_index not in parent.children (parent=%r, node=%r)" % (parent, node)
            for child_index in node.children:
                assert child_index < len(self._nodes), "node's child_index out of range: %r" % node
                child = self._nocache_nodes(child_index)
                assert child.parent == node_index, "node's child's parent do not match node_index (node=%r, child=%r)" % (node, child)
        # check caches
        for k in self.keys():
            assert self.index_to_key(k.index) == k.key, "index_to_key mismatch for %r" % (k,)
            assert self.key_to_index(k.key) == k.index, "key_to_index mismatch for %r" % (k,)
        for cacheid, cache in self._cache.items():
            nocache = getattr(self, "_no" + cacheid)
            # snapshot through dict.items() so that iterating never goes through
            # LRU.__getitem__, which reorders entries
            for key, value in list(dict.items(cache)):
                try:
                    real_value = nocache(*key)
                except Exception as exc:
                    raise AssertionError("corrupted cache for %s and key %s (cache=%s vs real=%s)" % (cacheid, key, value, exc)) from exc
                assert value == real_value, "corrupted cache for %s and key %s (cache=%s vs real=%s)" % (cacheid, key, value, real_value)
        return True

    @staticmethod
    def _seems_valid_key(key):
        """Return True if key supports slicing (key[:0]), i.e. can live in the trie."""
        try:
            _ = key[:0]
        except TypeError:
            return False
        return True

    @staticmethod
    def _check_compress_level(level):
        """Raise ValueError unless level is None, 0 or a zlib level (-1, 1..9)."""
        if level is not None and not (isinstance(level, int) and -1 <= level <= 9):
            raise ValueError("invalid compress level: %r" % (level,))

    @staticmethod
    def _encode(obj, level):
        """Serialize obj for storage: as-is (None), pickled (0) or pickled + zlib."""
        if level is None:
            return obj
        if level == 0:
            return pickle.dumps(obj)
        return zlib.compress(pickle.dumps(obj), level)

    @staticmethod
    def _decode(data, level):
        """Inverse of _encode() for the same level."""
        if level is None:
            return data
        if level == 0:
            return pickle.loads(data)
        return pickle.loads(zlib.decompress(data))

    def _normalize_index(self, index):
        """Map a possibly negative index to 0..len-1 (IndexError if out of range).

        Keeps a single cache key per value: -1 and len-1 must not be cached apart.
        """
        return range(len(self._values))[index]

    def _set_cache(self, key, value):
        """Create (int), resize (int) or remove (None) the cache for config key."""
        if key not in self._config:
            raise ValueError
        func = "_" + key
        if key in self._cache:
            assert isinstance(self._cache[key], LRU), "invalid cache type: %r" % self._cache[key]
            if value is None:
                # remove the cache: drop the per-instance wrapper so that the
                # plain class method is used again
                del self._cache[key]
                self.__dict__.pop(func, None)
            else:
                # update cache maxsize
                self._cache[key].maxsize = value
        else:
            # create a new cache and decorate appropriate function
            cache = LRU(value)
            if isinstance(cache, LRU):
                self._cache[key] = cache
                setattr(self, func, lru_cache(cache)(getattr(self, func)))
        self._config[key] = value

    def _find_node(self, key, get_node):
        """Walk down the trie and return (node, node_index) for key; KeyError if absent."""
        full_key = key
        children = self._root
        while children:
            for child_index in children:
                child = get_node(child_index)
                if key == child.key:
                    return child, child_index
                if child.key == key[:len(child.key)]:
                    # child.key is a proper prefix of key: descend into child
                    children = child.children
                    key = key[len(child.key):]
                    break
            else:
                raise KeyError(full_key)
        # ran out of children with part of the key left unmatched
        raise KeyError(full_key)

    def _nocache_keys(self, key):
        return self._find_node(key, self._nocache_nodes)

    def _cache_keys(self, key):
        return self._find_node(key, self._cache_nodes)

    @staticmethod
    def _full_key(node_index, get_node):
        """Rebuild the full key of a node by concatenating fragments up to the root."""
        node = get_node(node_index)
        key = node.key
        while node.parent is not None:
            node = get_node(node.parent)
            key = node.key + key
        return key

    def _nocache_indexes(self, node_index):
        return self._full_key(node_index, self._nocache_nodes)

    def _cache_indexes(self, node_index):
        return self._full_key(node_index, self._cache_nodes)

    def _nocache_nodes(self, node_index):
        return Node(*self._decode(self._nodes[node_index], self._config.get("compress_nodes")))

    def _cache_nodes(self, node_index):
        # replaced per instance by an LRU-cached wrapper when "cache_nodes" is set
        return self._nocache_nodes(node_index)

    def _nocache_values(self, index):
        return self._decode(self._values[index], self._config.get("compress_values"))

    def _cache_values(self, index):
        # replaced per instance by an LRU-cached wrapper when "cache_values" is set
        return self._nocache_values(index)

    def _create_node(self, key, parent, parent_index, children=(), index=None):
        """Append a new node, link it under parent (or the root) and return (node, node_index)."""
        node = Node(parent_index, index, children, key)
        node_index = len(self._nodes)
        self._nodes.append(self._format_node(node))
        if parent_index is None:
            self._root.add(node_index)
        else:
            parent.children.add(node_index)
            self._update_nodes(parent, parent_index)
        return node, node_index

    def _add_node(self, key, value):
        """
        Insert a key known to be absent and append its value.

        Walks down from the root consuming matching key fragments, then either
        puts the new node above a child whose key starts with the rest of key
        (that child is "moved" under it), splits a child on a common prefix,
        or adds a plain leaf under the deepest matching node.
        """
        parent, parent_index = None, None
        children = self._root
        moved = moved_index = None
        done = not children
        while not done:
            done = True
            for child_index in children:
                child = self._cache_nodes(child_index)
                assert child.key != key
                if child.key == key[:len(child.key)]:
                    # child.key is a prefix of key: descend into child
                    parent, parent_index = child, child_index
                    children = child.children
                    key = key[len(child.key):]
                    done = not children
                    break
                if key == child.key[:len(key)]:
                    # key is a prefix of child.key: the new node becomes child's parent
                    moved, moved_index = child, child_index
                    if parent is None:
                        self._root.remove(child_index)
                    else:
                        parent.children.remove(child_index)
                        self._update_nodes(parent, parent_index)
                    moved.key = moved.key[len(key):]
                    break
                if type(key) is type(child.key):
                    # split child on the longest common prefix, if any
                    prefix = key[:0]
                    for i, c in enumerate(key):
                        if child.key[i] != c:
                            prefix = key[:i]
                            break
                    if prefix:
                        if parent is None:
                            self._root.remove(child_index)
                        else:
                            parent.children.remove(child_index)
                        node, node_index = self._create_node(prefix, parent, parent_index, {child_index})
                        child.key = child.key[len(prefix):]
                        child.parent = node_index
                        self._update_nodes(child, child_index)
                        key = key[len(prefix):]
                        parent, parent_index = node, node_index
                        break
        children = {moved_index} if moved is not None else set()
        node, node_index = self._create_node(key, parent, parent_index, children, len(self))
        if moved is not None:
            moved.parent = node_index
            self._update_nodes(moved, moved_index)
        self._indexes.append(node_index)
        self._values.append(self._format_value(value))

    def _activate_node(self, node, node_index, value, key=None):
        """Give an existing branching node the next index and store its value."""
        node.index = len(self)
        self._update_nodes(node, node_index)
        self._indexes.append(node_index)
        self._values.append(self._format_value(value))

    def _update_nodes(self, node, node_index):
        """Store node at node_index and invalidate the caches for that slot."""
        self._nodes[node_index] = self._format_node(node)
        # just clear the cache(s)
        if "cache_nodes" in self._cache:
            del self._cache["cache_nodes"][node_index]
        if "cache_indexes" in self._cache:
            del self._cache["cache_indexes"][node_index]

    def _remove_node(self, node, node_index):
        """
        Remove a node with at most one child from the trie.

        A leaf is dropped, and its parent is then merged away if it is left as
        a key-less node with a single child; a node with one child is merged
        into that child. To keep ``_nodes`` dense the last node of the list is
        moved into the freed slot and every reference to it is rewritten.
        """
        assert len(node.children) <= 1, "node %d should not be removed: %r" % (node_index, node)

        def remove_node(node, node_index):
            assert len(node.children) <= 1, "node %d has more than 1 child: %r" % (node_index, node)
            # clear cache(s)
            if "cache_nodes" in self._cache:
                del self._cache["cache_nodes"][node_index]
            if "cache_indexes" in self._cache:
                del self._cache["cache_indexes"][node_index]
            last_node_index = len(self._nodes) - 1
            last_node = self._nodes.pop()
            if last_node_index != node_index:
                # also clear cache(s) for the last node
                if "cache_nodes" in self._cache:
                    del self._cache["cache_nodes"][last_node_index]
                if "cache_indexes" in self._cache:
                    del self._cache["cache_indexes"][last_node_index]
                # replace nodes, last_node will now refer previous last_node...
                self._nodes[node_index] = last_node
                last_node = self._nocache_nodes(node_index)
            # remove node from its parent's children
            if node.parent is None:
                self._root.remove(node_index)
            else:
                if node.parent != last_node_index:
                    parent = self._cache_nodes(node.parent)
                else:
                    # special case if parent was the last node we switch with...
                    node.parent = node_index
                    parent = last_node
                parent.children.remove(node_index)
            # if node has a child, node's child is now parent's child
            # also update child's key
            if len(node.children) == 1:
                child_index = node.children.pop()
                # special case if the child is the last_node: it is linked under
                # its old index here and relinked by the fix-up below
                if child_index == last_node_index:
                    child = last_node
                else:
                    child = self._cache_nodes(child_index)
                child.key = node.key + child.key
                child.parent = node.parent
                if node.parent is None:
                    self._root.add(child_index)
                else:
                    parent.children.add(child_index)
                # if child was the last_node, no need to update it twice
                if child_index != last_node_index:
                    self._update_nodes(child, child_index)  # cache(s) will be updated
            if node.parent is not None:
                # if parent was the last_node, no need to update it twice
                if node.parent != node_index:
                    self._update_nodes(parent, node.parent)
            # if the last node was not the removed node, we need to fix it
            if last_node_index != node_index:
                # cache(s) have already been cleared
                # fix reversed index if needed
                if last_node.index is not None:
                    self._indexes[last_node.index] = node_index
                # update parent's children...
                if last_node.parent is None:
                    self._root.remove(last_node_index)
                    self._root.add(node_index)
                else:
                    parent = self._cache_nodes(last_node.parent)
                    parent.children.remove(last_node_index)
                    parent.children.add(node_index)
                    self._update_nodes(parent, last_node.parent)  # cache(s) will be updated
                # ...and children's parent
                for child_index in last_node.children:
                    child = self._cache_nodes(child_index)
                    child.parent = node_index
                    self._update_nodes(child, child_index)  # cache(s) will be updated
                self._update_nodes(last_node, node_index)  # cache(s) will be updated (twice?)
            # will return the real parent index
            return node.parent

        if len(node.children) == 0:
            parent_index = remove_node(node, node_index)
            if parent_index is not None:
                parent = self._cache_nodes(parent_index)
                # a key-less node with a single child is useless: merge it
                if parent.index is None and len(parent.children) == 1:
                    self._remove_node(parent, parent_index)
        else:
            # simply update child and parent and remove node
            remove_node(node, node_index)

    def _update_values(self, value, index):
        """Update a stored value and drop its cache entry if needed."""
        index = self._normalize_index(index)
        self._values[index] = self._format_value(value)
        if "cache_values" in self._cache:
            # better del cache than update it because value can be a reference to an object
            del self._cache["cache_values"][index]

    def _format_value(self, value):
        return self._encode(value, self._config.get("compress_values"))

    def _format_node(self, node):
        # nodes are stored as plain tuples (children as a list) before encoding
        return self._encode((node.parent, node.index, list(node.children), node.key), self._config.get("compress_nodes"))

    def _setup(self):
        """
        Complete the config with the class defaults and (re)build the caches.

        Called by __init__ and after unpickling (caches are never pickled).
        """
        for k, v in self.default_config.items():
            self._config.setdefault(k, v)
        # the cache_keys option has been deprecated because it's too complex to implement
        # and requires so many operations it probably won't be helpful.
        self._config.pop("cache_keys", None)
        self._cache = dict()
        for k, v in list(self._config.items()):
            if k.startswith("cache_") and v is not None:
                self._set_cache(k, v)
if __name__ == '__main__':
    # Run some tests: random insertions/deletions with an integrity check after each step
    from random import choice, randint

    # logging is configured only when run as a script, never on import
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)s %(levelname)s %(message)s')

    def run_test(obj):
        """
        Stress-test obj until CTRL-C.

        Each iteration adds a random key (1 to 8 letters from "abcdef"), may
        delete a random index, then runs obj.check_integrity(). On failure the
        trie is printed and the AssertionError is re-raised.
        """
        logger.info("running tests... (CTRL-C to stop)")
        abc = "abcdef"
        try:
            while True:
                x = "".join(choice(abc) for _ in range(randint(1, 8)))
                logger.debug("adding %s", x)
                obj[x] = x
                if randint(0, len(obj)) > 0:
                    i = randint(0, len(obj) - 1)
                    logger.debug("removing k[%d] (%s)", i, obj[i])
                    del obj[i]
                obj.check_integrity()
        except AssertionError:
            logger.exception("integrity check failed")
            obj.show_tree()
            raise
        except KeyboardInterrupt:
            logger.info("success!")
        print("%r" % obj)

    k = kik()
    run_test(k)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment