Last active
June 2, 2020 00:18
-
-
Save fbparis/b3ddd5673b603b42c880974b23db7cda to your computer and use it in GitHub Desktop.
A replacement for the indexed trie data-structure with a very low memory footprint!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
kik (key to index to key) is a new implementation of the indexed trie (https://gist.github.com/fbparis/e114958a4a9be84146a1a579cf677e8d), reducing the memory footprint by a factor of 7 or more!
The main differences with indexed tries are:
* Trie nodes are no longer stored in Node objects but in a Python list. An internal mechanism uses the list to simulate a tree behaviour, without any recursive function.
* Nodes and values stored in the kik can be compressed with several compression levels:
    * None: no compression at all
    * 0: data are stored as pickle.dumps (memory use divided by 3-4)
    * -1, 1 to 9: pickle.dumps are compressed with zlib
* Some internal methods can be cached with a custom implementation of an LRU cache.
""" | |
import pickle | |
import zlib | |
import logging | |
from collections import OrderedDict, namedtuple | |
from functools import wraps | |
from smart_open import smart_open | |
__author__ = "@fbparis"

# Unique marker meaning "no value supplied", usable where None is itself a
# legitimate value.
_SENTINEL = object()

# Statistics snapshot returned by LRU.cache_info().
CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize")

# NOTE: both the root logging configuration and the DEBUG level forced on this
# module's logger are applied at import time.
logging.basicConfig(
    format='%(asctime)s %(name)s %(levelname)s %(message)s',
    level=logging.WARNING,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def lru_cache(cache):
    """
    A replacement for functools.lru_cache() built on a caller-supplied mapping.

    ``cache`` is any mapping that raises KeyError on a miss (typically an LRU
    instance). Because the cache object is passed in explicitly, the decorator
    can be applied to bound methods and the cache can be inspected or
    invalidated from the outside.

    The cache key is the tuple of positional arguments of the call. Keyword
    arguments are not part of the key, so a call that uses any keyword
    argument bypasses the cache entirely instead of risking a wrong hit.

    Returns a decorator; the decorated function exposes the original one as
    ``__wrapped__`` (set by functools.wraps).
    """
    def decorator(func):
        logger.debug("assigning cache %r to function %s", cache, func.__name__)

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            if kwargs:
                # kwargs cannot be expressed in the positional-only cache key
                return func(*args, **kwargs)
            try:
                ret = cache[args]
            except KeyError:
                pass
            else:
                logger.debug("cached value returned for function %s", func.__name__)
                return ret
            # Computed outside the ``except`` block so that an exception raised
            # by ``func`` is not chained to the cache-miss KeyError.
            ret = func(*args)
            logger.debug("cache updated for function %s", func.__name__)
            cache[args] = ret
            return ret
        return wrapped_func
    return decorator
class LRU(OrderedDict):
    """
    Custom implementation of a LRU cache, built on top of an OrderedDict.

    The oldest entry sits at the front of the ordering: a successful lookup
    moves the entry to the end, and inserting past ``maxsize`` evicts from the
    front. Hits and misses are counted on every ``cache[key]`` lookup.

    Two deliberate quirks, both relied upon by the callers in this file:

    * ``LRU(None)`` returns ``None`` instead of an instance ("no cache").
    * ``del cache[x]`` removes the entry stored under the 1-tuple ``(x,)``
      (the key the lru_cache decorator builds for a single-argument call)
      and silently ignores a missing entry.
    """
    __slots__ = "_hits", "_misses", "_maxsize"

    def __new__(cls, maxsize=128, *args, **kwargs):
        # Extra arguments are accepted here only because __init__ forwards
        # them to OrderedDict; they play no role in object creation.
        if maxsize is None:
            return None
        return super().__new__(cls)

    def __init__(self, maxsize=128, *args, **kwargs):
        """
        maxsize: maximum number of entries (int >= 2).
        *args, **kwargs: initial content, forwarded to OrderedDict.
        Raises TypeError / ValueError for an invalid maxsize.
        """
        self._hits = 0
        self._misses = 0
        self.maxsize = maxsize
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        """Return the cached value and mark it as most recently used."""
        try:
            value = super().__getitem__(key)
        except KeyError:
            self._misses += 1
            raise
        self.move_to_end(key)
        self._hits += 1
        return value

    def __setitem__(self, key, value):
        """Store a value, evicting the oldest entry when the cache overflows."""
        super().__setitem__(key, value)
        if len(self) > self._maxsize:
            self._evict_oldest()

    def __delitem__(self, key):
        """Forget the entry stored under ``(key,)``; a missing entry is not an error."""
        try:
            super().__delitem__((key,))
        except KeyError:
            pass

    def __repr__(self):
        return "<%s object at %s: %s>" % (self.__class__.__name__, hex(id(self)), self.cache_info())

    def _evict_oldest(self):
        """Remove the least recently used entry, whatever the shape of its key."""
        # Goes straight to OrderedDict.__delitem__ because the overridden
        # __delitem__ above re-wraps its argument in a tuple.
        super().__delitem__(next(iter(self)))

    def cache_info(self):
        """Return a CacheInfo(hits, misses, maxsize, currsize) snapshot."""
        return CacheInfo(self._hits, self._misses, self._maxsize, len(self))

    def clear(self):
        """Drop every entry and reset the hit/miss counters."""
        super().clear()
        self._hits, self._misses = 0, 0

    @property
    def maxsize(self):
        return self._maxsize

    @maxsize.setter
    def maxsize(self, maxsize):
        """Resize the cache, evicting the oldest entries if it shrinks."""
        if not isinstance(maxsize, int):
            raise TypeError("maxsize must be an int, got %r" % (maxsize,))
        elif maxsize < 2:
            raise ValueError("maxsize must be >= 2, got %r" % (maxsize,))
        elif maxsize & (maxsize - 1) != 0:
            logger.warning("LRU feature performs best when maxsize is a power-of-two, maybe.")
        while maxsize < len(self):
            self._evict_oldest()
        self._maxsize = maxsize
class Node(object):
    """
    A single node in the trie.

    parent: position of the parent node, or None for a child of the root.
    index: index of the item ending at this node, or None if no item ends here.
    children: positions of the child nodes (always stored as a fresh set).
    key: the key fragment carried by this node.
    """
    __slots__ = "parent", "index", "children", "key"

    def __init__(self, parent, index, children, key):
        self.parent = parent
        self.index = index
        # copy into a set so the node never aliases the caller's container
        self.children = set(children)
        self.key = key

    def __eq__(self, other):
        """Field-by-field equality; foreign types compare unequal instead of raising."""
        if not isinstance(other, Node):
            return NotImplemented
        return all(getattr(self, name) == getattr(other, name) for name in self.__slots__)

    def __repr__(self):
        return "<%s object at %s: parent=%s, index=%s, children=%s, key=%s>" % (self.__class__.__name__, hex(id(self)), self.parent, self.index, self.children, self.key)
class kik_key(object):
    """
    A pair (index, key) representing a key in the kik object.

    index: integer index of the item.
    key: the full key of the item.
    """
    __slots__ = "index", "key"

    def __init__(self, index, key):
        self.index = index
        self.key = key

    def __repr__(self):
        return "(%d, %s)" % (self.index, self.key)

    def __eq__(self, other):
        """Field-by-field equality; foreign types compare unequal instead of raising."""
        if not isinstance(other, kik_key):
            return NotImplemented
        return all(getattr(self, name) == getattr(other, name) for name in self.__slots__)
class kik(object):
    """
    key-to-index-to-key object.

    key is any subscriptable (it must support slicing, e.g. str or tuple) and
    the matching index is an autoincrement integer. values can be any objects.

    Internal layout:

    * _nodes: list of stored trie nodes, each one the (parent, index,
      children, key) tuple of a Node, optionally pickled / zlib-compressed.
    * _indexes: _indexes[i] is the position in _nodes of the node where
      item i ends.
    * _values: _values[i] is the (optionally pickled / compressed) value of
      item i.
    * _root: set of positions of the nodes that have no parent.
    * _cache: dict mapping a cache name ("cache_nodes", ...) to its LRU.
    * _config: per-instance options ("compress_*", "cache_*" and optionally
      "default_factory").

    "__dict__" is kept in __slots__ because enabling a cache stores the
    decorated method as an instance attribute.
    """
    __slots__ = "_config", "_indexes", "_values", "_nodes", "_cache", "_root", "__dict__"

    default_config = {
        "compress_nodes": 0,
        "compress_values": None,
        "cache_keys": None,
        "cache_nodes": 128,
        "cache_indexes": 128,
        "cache_values": None
    }

    def __init__(self, *args, **kwargs):
        """
        kik([source], **config)

        source: optional kik, mapping (anything with .items()) or iterable of
        (key, value) pairs used to populate the new object. When source is a
        kik its configuration is inherited.
        config: overrides for the entries of default_config, plus an optional
        "default_factory".

        Raises ValueError if more than one positional argument is given.
        """
        if len(args) > 1:
            raise ValueError
        # Work on a private copy: _setup() edits the configuration, and doing
        # so on the shared class-level dict would break every later instance.
        self._config = dict(self.default_config)
        if args and isinstance(args[0], self.__class__):
            self._config.update(args[0]._config)
        self._config.update(kwargs)
        self._indexes = []
        self._values = []
        self._nodes = []
        self._root = set()
        self._setup()
        if args:
            try:
                items = args[0].items()
            except AttributeError:
                items = args[0]
            for k, v in items:
                # iterating a kik yields kik_key objects: keep only the key
                self[k.key if isinstance(k, kik_key) else k] = v

    def __getstate__(self):
        """
        Prevent the caches to be pickled
        """
        return self._config, self._indexes, self._values, self._nodes, self._root

    def __setstate__(self, state):
        """
        Reset the caches on unpickling
        """
        self._config, self._indexes, self._values, self._nodes, self._root = state
        self._setup()

    def __eq__(self, other):
        """Two kik are equal when their stored nodes and stored values are equal."""
        if not isinstance(other, kik):
            return NotImplemented
        return self._nodes == other._nodes and self._values == other._values

    def __len__(self):
        return len(self._indexes)

    def __repr__(self):
        return "<%s object at %s: len=%d, nodes=%d, config=%s>" % (self.__class__.__name__, hex(id(self)), len(self._indexes), len(self._nodes), self._config)

    def __str__(self):
        ret = ["%s: %s" % (k, v) for k, v in self.items()]
        return "{%s}" % ", ".join(ret)

    def __iter__(self):
        return self.keys()

    def __contains__(self, key_or_index):
        """
        Return True if key or index exists in the kik.
        """
        if isinstance(key_or_index, kik_key):
            return 0 <= key_or_index.index < len(self)
        if isinstance(key_or_index, int):
            return 0 <= key_or_index < len(self)
        if not self._seems_valid_key(key_or_index):
            return False
        try:
            node, _ = self._cache_keys(key_or_index)
        except KeyError:
            return False
        # a node without index is only an internal branching point
        return node.index is not None

    def __getitem__(self, key_or_index):
        """
        Return the value for a kik_key, an integer index, a slice of indexes
        (as a list) or a key.

        A missing key returns default_factory() when a default factory is
        set, otherwise raises KeyError. An out-of-range index raises
        IndexError; an unsupported key type raises TypeError.
        """
        if isinstance(key_or_index, kik_key):
            return self._cache_values(key_or_index.index)
        if isinstance(key_or_index, int):
            return self._cache_values(key_or_index)
        if isinstance(key_or_index, slice):
            return [self._cache_values(i) for i in range(*key_or_index.indices(len(self)))]
        if not self._seems_valid_key(key_or_index):
            # same contract as __setitem__ and __delitem__
            raise TypeError("invalid key type: %r" % (key_or_index,))
        try:
            node, _ = self._cache_keys(key_or_index)
        except KeyError:
            node = None
        if node is None or node.index is None:
            if self.default_factory is _SENTINEL:
                raise KeyError(key_or_index)
            return self.default_factory()
        return self._cache_values(node.index)

    def __setitem__(self, key_or_index, value):
        """
        Set the value of an existing item (by kik_key or integer index) or of
        a key, creating the key if needed.

        Raises NotImplementedError for slices and TypeError for unsupported
        key types.
        """
        if isinstance(key_or_index, kik_key):
            self._update_values(value, key_or_index.index)
        elif isinstance(key_or_index, int):
            self._update_values(value, key_or_index)
        elif isinstance(key_or_index, slice):
            raise NotImplementedError
        elif self._seems_valid_key(key_or_index):
            try:
                node, node_index = self._cache_keys(key_or_index)
            except KeyError:
                # create a new node
                self._add_node(key_or_index, value)
            else:
                if node.index is None:
                    # update node index and add value to _values
                    self._activate_node(node, node_index, value)
                else:
                    # update _values
                    self._update_values(value, node.index)
        else:
            raise TypeError

    def __delitem__(self, key_or_index):
        """
        Delete an item given its kik_key, integer index or key.

        The last item takes the index of the deleted one, so indexes stay
        contiguous but the index of the former last item changes.

        Raises KeyError for a missing key, IndexError for an out-of-range
        index, NotImplementedError for slices and TypeError for unsupported
        key types.
        """
        if isinstance(key_or_index, (kik_key, int)):
            position = key_or_index.index if isinstance(key_or_index, kik_key) else key_or_index
            node_index = self._indexes[position]
            node = self._cache_nodes(node_index)
        elif isinstance(key_or_index, slice):
            raise NotImplementedError
        elif self._seems_valid_key(key_or_index):
            node, node_index = self._cache_keys(key_or_index)
            if node.index is None:
                raise KeyError(key_or_index)
        else:
            raise TypeError
        # clear cache entries
        if "cache_values" in self._cache:
            del self._cache["cache_values"][node.index]
        # switch last indexes and values items with the deleted one
        last_index, last_value = self._indexes.pop(), self._values.pop()
        if node.index != len(self):
            # update the trie
            last_node = self._cache_nodes(last_index)
            if "cache_indexes" in self._cache:
                del self._cache["cache_indexes"][last_index]
            if "cache_values" in self._cache:
                del self._cache["cache_values"][last_node.index]
            last_node.index = node.index
            self._indexes[node.index] = last_index
            self._values[node.index] = last_value
            self._update_nodes(last_node, last_index)
        # update the internal radix trie
        assert node.index is not None, "node.index for index %d should not be None: %r" % (node_index, node)
        if len(node.children) > 1:
            # case 1: node has more than 1 child, only turn index off
            node.index = None
            self._update_nodes(node, node_index)
        else:
            # cases 2 and 3 are passed to _remove_node
            self._remove_node(node, node_index)

    @classmethod
    def fromkeys(cls, keys, value=_SENTINEL, **kwargs):
        """
        Build a kik from an iterable of keys.

        Every key gets ``value``; when value is omitted each key gets
        default_factory() if a "default_factory" was passed in kwargs, else
        None. kwargs are forwarded to the constructor.
        """
        obj = cls(**kwargs)
        for key in keys:
            if value is not _SENTINEL:
                obj[key] = value
            elif obj.default_factory is not _SENTINEL:
                obj[key] = obj.default_factory()
            else:
                obj[key] = None
        return obj

    @classmethod
    def fromsplit(cls, keys, value=_SENTINEL, **kwargs):
        """
        Same as fromkeys() with the keys obtained from ``keys.split()``.
        """
        return cls.fromkeys(keys.split(), value, **kwargs)

    @classmethod
    def load(cls, name):
        """
        Load the object pickled in the file "<name>.<class name>.bz2".

        SECURITY: the file is read with pickle.load, which can execute
        arbitrary code; only load files coming from a trusted source.
        """
        with smart_open("%s.%s.bz2" % (name, cls.__name__), "rb") as f:
            obj = pickle.load(f)
        assert isinstance(obj, cls)
        return obj

    def save(self, name):
        """
        Pickle the object into the file "<name>.<class name>.bz2".
        """
        with smart_open("%s.%s.bz2" % (name, self.__class__.__name__), "wb") as f:
            pickle.dump(self, f)

    def copy(self):
        """
        Return a shallow copy of the kik
        """
        return self.__class__(self)

    def keys(self, prefix=None):
        """
        Generate the keys of the kik as kik_key objects.

        Without prefix, keys come in index order. With a prefix, only keys
        starting with it are generated, found by a level-by-level walk of the
        trie (so not in index order).

        Raises ValueError (on first iteration) if prefix is not a valid key.
        """
        if prefix is None:
            for i, node_index in enumerate(self._indexes):
                yield kik_key(i, self._nocache_indexes(node_index))
            return
        if not self._seems_valid_key(prefix):
            raise ValueError
        empty = prefix[:0]
        # each entry: (key accumulated so far, part of the prefix still to be
        # matched, position of the node to examine)
        pending = [(empty, prefix, child_index) for child_index in self._root]
        while pending:
            next_level = []
            for path, rest, child_index in pending:
                child = self._cache_nodes(child_index)
                if rest == child.key[:len(rest)]:
                    # prefix fully matched: the whole subtree qualifies
                    full = path + child.key
                    next_level.extend([(full, empty, i) for i in child.children])
                    if child.index is not None:
                        yield kik_key(child.index, full)
                elif rest[:len(child.key)] == child.key:
                    # node consumed part of the prefix: keep matching below
                    full = path + child.key
                    remaining = rest[len(child.key):]
                    next_level.extend([(full, remaining, i) for i in child.children])
            pending = next_level

    def values(self, prefix=None):
        """
        Generate the values, in the same order as keys(prefix).
        """
        if prefix is None:
            for i in range(len(self._values)):
                yield self._nocache_values(i)
        else:
            for key in self.keys(prefix):
                yield self._nocache_values(key.index)

    def items(self, prefix=None):
        """
        Generate (kik_key, value) pairs, in the same order as keys(prefix).
        """
        for key in self.keys(prefix):
            yield key, self._nocache_values(key.index)

    def show_tree(self, children=None, level=0):
        """
        Pretty print the internal trie (recursive function).
        """
        if children is None:
            children = self._root
        for child_index in children:
            child = self._nocache_nodes(child_index)
            print("-" * level + "<key=%s, index=%s>" % (child.key, child.index))
            self.show_tree(child.children, level + 1)

    def cache_info(self):
        """
        Print the statistics of every active cache.
        """
        if self._cache is not None:
            for name, cache in self._cache.items():
                print("%s: %r" % (name, cache))

    @property
    def default_factory(self):
        """The configured default factory, or _SENTINEL when there is none."""
        return self._config.get("default_factory", _SENTINEL)

    def set_default_factory(self, factory=_SENTINEL):
        """
        Set the default factory, or remove it when called without argument.

        ``factory`` is either a callable taking no argument or a plain default
        value (which is then wrapped in a function returning it).
        """
        if factory is not _SENTINEL:
            # kik will act like a collections.defaultdict except in some cases because the __missing__
            # method is not implemented here (on purpose).
            # That means that simple lookups on a non existing key will return a default value without adding
            # the key, which is the more logical way to do.
            # Also means that if your default_factory is for example "list", you won't be able to create new
            # items with "append" or "extend" methods which are updating the list itself.
            # Instead you have to do something like kik_object["newkey"] += [...]
            try:
                factory()
            except TypeError:
                # a default value is also accepted as default_factory, even "None"
                # NOTE(review): the lambda ends up in _config, which is part of the
                # pickled state, and pickle cannot serialize lambdas.
                self._config["default_factory"] = lambda: factory
            else:
                self._config["default_factory"] = factory
        else:
            self._config.pop("default_factory", None)

    @property
    def cache(self):
        """
        The dict of active caches (cache name -> LRU).
        """
        return self._cache

    def get_cache(self, key):
        """
        Return the LRU of the cache named "cache_<key>".

        Raises ValueError if there is no such active cache.
        """
        try:
            return self._cache["cache_" + key]
        except (KeyError, TypeError):
            raise ValueError("no active cache named %r" % (key,)) from None

    def set_cache(self, key, value):
        """
        Create, resize (value is the new maxsize) or remove (value is None)
        the cache named "cache_<key>".
        """
        key = "cache_" + key
        return self._set_cache(key, value)

    def remove_cache(self, key):
        """
        Remove the cache named "cache_<key>".
        """
        self.set_cache(key, None)

    def compress_level(self, key):
        """
        Return the compression level of "nodes" or "values".

        Raises ValueError for any other key.
        """
        try:
            return self._config["compress_" + key]
        except KeyError:
            raise ValueError

    def set_compress_level(self, key, value):
        """
        Change the compression level of "nodes" or "values" and re-encode the
        stored data accordingly.

        value: None (raw objects), 0 (pickle only) or a zlib level.
        """
        current = self.compress_level(key)
        if current == value:
            return
        store = getattr(self, "_" + key)
        # Re-encode into a new list first, so that a failure on any element
        # (e.g. an invalid zlib level) leaves the current data untouched.
        converted = [self._pack(self._unpack(x, current), value) for x in store]
        store[:] = converted
        self._config["compress_" + key] = value

    def key_to_index(self, key):
        """
        Return the index of a key.

        Raises KeyError if the key is not in the kik.
        """
        # _cache_keys returns a (node, node position) pair
        node, _ = self._cache_keys(key)
        if node.index is None:
            # internal branching node: no item ends here
            raise KeyError(key)
        return node.index

    def index_to_key(self, index):
        """
        Return the key of the item at ``index``.
        """
        return self._cache_indexes(self._indexes[index])

    def key_from_index(self, index):
        """
        Alias of index_to_key().
        """
        return self.index_to_key(index)

    def index_from_key(self, key):
        """
        Alias of key_to_index().
        """
        return self.key_to_index(key)

    def check_integrity(self):
        """
        Check the consistency of the trie, of the index tables and of the
        caches. Return True, or raise AssertionError on the first problem.
        """
        logger.info("checking integrity...")
        # assert len([1 for k, v in self.items() if k.key != v]) == 0, "key do not match value"
        assert len([1 for k, v in self.items() if self._nocache_nodes(self._indexes[k.index]).index != k.index]) == 0, "index do not match node.index"
        for node_index in self._root:
            assert node_index < len(self._nodes), "root's child index out of range: %d" % node_index
            node = self._nocache_nodes(node_index)
            assert node.parent is None, "root's child has bad parent: %r" % node
        for node_index in range(len(self._nodes)):
            node = self._nocache_nodes(node_index)
            if node.index is not None:
                assert node.index < len(self), "node.index out of range: %r" % node
                assert self._indexes[node.index] == node_index, "index do not match node_index: %r" % node
            if node.parent is None:
                assert node_index in self._root, "node is not in root: %r" % node
            else:
                assert node.parent < len(self._nodes), "node.parent out of range: %r" % node
                parent = self._nocache_nodes(node.parent)
                assert node_index in parent.children, "node_index not in parent.children (parent=%r, node=%r)" % (parent, node)
            for child_index in node.children:
                assert child_index < len(self._nodes), "node's child_index out of range: %r" % node
                child = self._nocache_nodes(child_index)
                assert child.parent == node_index, "node's child's parent do not match node_index (node=%r, child=%r)" % (node, child)
        # check caches: warm them up, then compare each entry with a fresh computation
        for k in self.keys():
            self.index_to_key(k.index)
            self.key_to_index(k.key)
        for cacheid, cache in self._cache.items():
            nocache = getattr(self, '_no' + cacheid)
            for key, value in cache.items():
                try:
                    real_value = nocache(*key)
                except Exception as real_value:
                    raise AssertionError("corrupted cache for %s and key %s (cache=%s vs real=%s)" % (cacheid, key, value, real_value))
                assert value == real_value, "corrupted cache for %s and key %s (cache=%s vs real=%s)" % (cacheid, key, value, real_value)
        return True

    @staticmethod
    def _seems_valid_key(key):
        """
        Return True if key supports slicing, which is all the trie needs.
        """
        try:
            key[:0]
        except (TypeError, KeyError):
            # KeyError: a mapping asked for a slice key; slices are hashable
            # on recent Python versions, so this no longer fails as TypeError
            return False
        return True

    @staticmethod
    def _pack(obj, level):
        """Encode obj for storage: as is (None), pickled (0) or pickled + zlib (other levels)."""
        if level is None:
            return obj
        data = pickle.dumps(obj)
        if level == 0:
            return data
        return zlib.compress(data, level)

    @staticmethod
    def _unpack(blob, level):
        """Reverse of _pack() for data stored with the given level."""
        if level is None:
            return blob
        if level == 0:
            return pickle.loads(blob)
        return pickle.loads(zlib.decompress(blob))

    def _set_cache(self, key, value):
        """
        Create, resize or remove the cache whose full name is ``key``.

        Creating a cache replaces the method "_<key>" on the instance by its
        lru_cache-decorated version; removing it restores the plain method.

        Raises ValueError if key is not a known configuration entry.
        """
        if key not in self._config:
            raise ValueError
        func = "_" + key
        if key in self._cache:
            assert isinstance(self._cache[key], LRU), "invalid cache type: %r" % self._cache[key]
            if value is None:
                # remove a cache
                del self._cache[key]
                setattr(self, func, getattr(self, func).__wrapped__)
            else:
                # update cache maxsize
                self._cache[key].maxsize = value
        else:
            # create a new cache and decorate appropriate function
            cache = LRU(value)
            # LRU(None) returns None: nothing to install in that case
            if isinstance(cache, LRU):
                self._cache[key] = cache
                setattr(self, func, lru_cache(cache)(getattr(self, func)))
        self._config[key] = value

    def _find_node(self, key, fetch):
        """
        Walk the trie and return (node, node position) for ``key``.

        fetch: the function used to read a node from its position.
        Raises KeyError if no node matches the key exactly.
        """
        wanted = key
        children = self._root
        while children:
            for child_index in children:
                child = fetch(child_index)
                if key == child.key:
                    return child, child_index
                if child.key == key[:len(child.key)]:
                    # this node holds the beginning of the key: go down
                    children = child.children
                    key = key[len(child.key):]
                    break
            else:
                # no child matches the rest of the key
                break
        raise KeyError(wanted)

    def _nocache_keys(self, key):
        """
        Return (node, node position) for a key, reading nodes without cache.
        Raises KeyError if the key matches no node.
        """
        return self._find_node(key, self._nocache_nodes)

    def _cache_keys(self, key):
        """
        Return (node, node position) for a key, reading nodes through
        _cache_nodes. Raises KeyError if the key matches no node.

        The returned node may be an internal one (node.index is None).
        """
        return self._find_node(key, self._cache_nodes)

    def _build_key(self, node_index, fetch):
        """Rebuild a full key by joining the fragments from a node up to the root."""
        node = fetch(node_index)
        key = node.key
        while node.parent is not None:
            node = fetch(node.parent)
            key = node.key + key
        return key

    def _nocache_indexes(self, node_index):
        """
        Return the full key ending at the given node, reading nodes without cache.
        """
        return self._build_key(node_index, self._nocache_nodes)

    def _cache_indexes(self, node_index):
        """
        Return the full key ending at the given node. This method is the one
        decorated when the "cache_indexes" cache is active.
        """
        return self._build_key(node_index, self._cache_nodes)

    def _nocache_nodes(self, node_index):
        """
        Decode the stored node at the given position into a new Node object.
        """
        return Node(*self._unpack(self._nodes[node_index], self._config.get("compress_nodes")))

    def _cache_nodes(self, node_index):
        """
        Same as _nocache_nodes(). This method is the one decorated when the
        "cache_nodes" cache is active.
        """
        return self._nocache_nodes(node_index)

    def _nocache_values(self, index):
        """
        Decode and return the stored value at the given index.
        """
        return self._unpack(self._values[index], self._config.get("compress_values"))

    def _cache_values(self, index):
        """
        Same as _nocache_values(). This method is the one decorated when the
        "cache_values" cache is active.
        """
        return self._nocache_values(index)

    def _create_node(self, key, parent, parent_index, children=(), index=None):
        """
        Append a new node to _nodes and attach it to its parent (or to the
        root when parent_index is None). Return (node, node position).
        """
        node = Node(parent_index, index, children, key)
        node_index = len(self._nodes)
        self._nodes.append(self._format_node(node))
        if parent_index is None:
            self._root.add(node_index)
        else:
            parent.children.add(node_index)
            self._update_nodes(parent, parent_index)
        return node, node_index

    def _add_node(self, key, value):
        """
        Insert a key that matches no existing node, together with its value.

        Walks down the trie; at each level a child either
        1. holds the beginning of the key: go down into it,
        2. starts with the whole remaining key: the new node is inserted
           above that child ("moved" child),
        3. shares only a prefix with the key: an internal node holding the
           common prefix is created above that child.
        """
        parent, parent_index = None, None
        children = self._root
        moved = None
        done = len(children) == 0
        while not done:
            done = True
            for child_index in children:
                child = self._cache_nodes(child_index)
                assert child.key != key
                if child.key == key[:len(child.key)]:
                    # case 1
                    parent, parent_index = child, child_index
                    children = child.children
                    key = key[len(child.key):]
                    done = len(children) == 0
                    break
                elif key == child.key[:len(key)]:
                    # case 2: detach the child, it is re-attached under the new node below
                    moved, moved_index = child, child_index
                    if parent is None:
                        self._root.remove(moved_index)
                    else:
                        parent.children.remove(moved_index)
                        self._update_nodes(parent, parent_index)
                    moved.key = moved.key[len(key):]
                    break
                elif type(key) is type(child.key):
                    # case 3: look for the longest common prefix (may be empty)
                    prefix = key[:0]
                    for i, c in enumerate(key):
                        if child.key[i] != c:
                            prefix = key[:i]
                            break
                    if prefix:
                        if parent is None:
                            self._root.remove(child_index)
                        else:
                            parent.children.remove(child_index)
                        node, node_index = self._create_node(prefix, parent, parent_index, {child_index})
                        child.key = child.key[len(prefix):]
                        child.parent = node_index
                        self._update_nodes(child, child_index)
                        key = key[len(prefix):]
                        parent, parent_index = node, node_index
                        break
        if moved is not None:
            children = {moved_index}
        else:
            children = set()
        node, node_index = self._create_node(key, parent, parent_index, children, len(self))
        self._indexes.append(node_index)
        self._values.append(self._format_value(value))
        if moved is not None:
            moved.parent = node_index
            self._update_nodes(moved, moved_index)

    def _activate_node(self, node, node_index, value, key=None):
        """
        Turn an existing internal node into an item: give it the next index
        and store its value. ``key`` is unused.
        """
        node.index = len(self)
        self._indexes.append(node_index)
        self._values.append(self._format_value(value))
        self._update_nodes(node, node_index)

    def _forget_node(self, node_index):
        """Drop the cache entries tied to a node position."""
        if "cache_nodes" in self._cache:
            del self._cache["cache_nodes"][node_index]
        if "cache_indexes" in self._cache:
            del self._cache["cache_indexes"][node_index]

    def _update_nodes(self, node, node_index):
        """
        Store ``node`` at the given position and invalidate its cache entries.
        """
        self._nodes[node_index] = self._format_node(node)
        # just clear the cache(s)
        self._forget_node(node_index)

    def _remove_node(self, node, node_index):
        """
        Remove from the trie a node having at most one child.

        A single child is merged upward (it inherits the node's key fragment
        and parent). When the node had no child and its parent is left as an
        internal node with a single child, the parent is removed as well.
        """
        assert len(node.children) <= 1, "node %d should not be removed: %r" % (node_index, node)

        def remove_node(node, node_index):
            """
            Remove one node from _nodes, moving the last stored node into the
            freed position. Return the (possibly relocated) position of the
            removed node's parent, or None.
            """
            assert len(node.children) <= 1, "node %d has more than 1 child: %r" % (node_index, node)
            # clear cache(s)
            self._forget_node(node_index)
            last_node_index = len(self._nodes) - 1
            last_node = self._nodes.pop()
            if last_node_index != node_index:
                # also clear cache(s) for the last node
                self._forget_node(last_node_index)
                # replace nodes, last_node will now refer previous last_node...
                self._nodes[node_index] = last_node
                last_node = self._nocache_nodes(node_index)
            # remove node from its parent's children
            if node.parent is None:
                self._root.remove(node_index)
            else:
                if node.parent != last_node_index:
                    parent = self._cache_nodes(node.parent)
                else:
                    # special case if parent was the last node we switch with...
                    node.parent = node_index
                    parent = last_node
                parent.children.remove(node_index)
            # if node has a child, node's child is now parent's child
            # also update child's key
            if len(node.children) == 1:
                child_index = node.children.pop()
                # special case if the child is the last_node
                if child_index == last_node_index:
                    child = last_node
                else:
                    child = self._cache_nodes(child_index)
                child.key = node.key + child.key
                child.parent = node.parent
                if node.parent is None:
                    self._root.add(child_index)
                else:
                    parent.children.add(child_index)
                # if child was the last_node, no need to update it twice
                if child_index != last_node_index:
                    self._update_nodes(child, child_index)  # cache(s) will be updated
            if node.parent is not None:
                # if parent was the last_node, no need to update it twice
                if node.parent != node_index:
                    self._update_nodes(parent, node.parent)
            # if the last node was not the removed node, we need to fix it
            if last_node_index != node_index:
                # cache(s) have already been cleared
                # fix reversed index if needed
                if last_node.index is not None:
                    self._indexes[last_node.index] = node_index
                # update parent's children...
                if last_node.parent is None:
                    self._root.remove(last_node_index)
                    self._root.add(node_index)
                else:
                    parent = self._cache_nodes(last_node.parent)
                    parent.children.remove(last_node_index)
                    parent.children.add(node_index)
                    self._update_nodes(parent, last_node.parent)  # cache(s) will be updated
                # ...and children's parent
                for child_index in last_node.children:
                    child = self._cache_nodes(child_index)
                    child.parent = node_index
                    self._update_nodes(child, child_index)  # cache(s) will be updated
                self._update_nodes(last_node, node_index)  # cache(s) will be updated (twice?)
            # will return the real parent index
            return node.parent

        if len(node.children) == 0:
            parent_index = remove_node(node, node_index)
            if parent_index is not None:
                parent = self._cache_nodes(parent_index)
                if parent.index is None and len(parent.children) == 1:
                    self._remove_node(parent, parent_index)
        else:
            # simply update child and parent and remove node
            remove_node(node, node_index)

    def _update_values(self, value, index):
        """
        update values and cache if needed
        """
        self._values[index] = self._format_value(value)
        if "cache_values" in self._cache:
            # better del cache than update it because value can be a reference to an object
            del self._cache["cache_values"][index]

    def _format_value(self, value):
        """
        Encode a value for storage according to "compress_values".
        """
        return self._pack(value, self._config.get("compress_values"))

    def _format_node(self, node):
        """
        Encode a node for storage, as a (parent, index, children, key) tuple
        packed according to "compress_nodes".
        """
        record = (node.parent, node.index, list(node.children), node.key)
        return self._pack(record, self._config.get("compress_nodes"))

    def _setup(self):
        """
        Parse and update config with the default values of the class
        """
        # the cache_keys option has been deprecated because it's too complex to implement
        # and requires so many operations it probably won't be helpful.
        # pop() with a default: the entry is already absent from a configuration
        # inherited from another kik or restored by unpickling.
        self._config.pop("cache_keys", None)
        self.set_default_factory(self.default_factory)
        self._cache = dict()
        # iterate over a snapshot, _set_cache() writes into the configuration
        for k, v in list(self._config.items()):
            if k.startswith("cache_") and v is not None:
                self._set_cache(k, v)
if __name__ == '__main__':
    # Run some tests
    from random import choice, randint

    def run_test(obj):
        """
        Fuzz ``obj`` until CTRL-C or an integrity failure: each round adds a
        random key, usually removes a random item, then checks integrity.
        The final state of the object is printed at the end.
        """
        logger.info("running tests... (CTRL-C to stop)")
        alphabet = "abcdef"
        while True:
            try:
                word = "".join([choice(alphabet) for _ in range(randint(1, 8))])
                logger.info("adding %s" % word)
                obj[word] = word
                # skip the removal from time to time
                if randint(0, len(obj)) > 0:
                    victim = randint(0, len(obj) - 1)
                    logger.info("removing k[%d] (%s)" % (victim, obj[victim]))
                    del obj[victim]
                try:
                    obj.check_integrity()
                except AssertionError as failure:
                    logger.error(failure)
                    break
            except KeyboardInterrupt:
                logger.info("success!")
                break
        print("%r" % obj)
        print(obj)
        obj.show_tree()
        obj.cache_info()

    k = kik()
    run_test(k)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment