Created
August 13, 2017 05:18
-
-
Save pstch/4a625d4d8e4de4c674ab75125684f4e4 to your computer and use it in GitHub Desktop.
Python deepmerge algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""Deep merge algorithm -- Merges multiple values, combining dictionaries by
their keys, tuples by their indices, chaining lists and sets, and
replacing by the last merged value for any other type.
"""
# 0. Imports & metadata | |
# ============================================================================= | |
from itertools import chain | |
__author__ = 'Hugo Geoffroy' | |
__copyright__ = "Copyright 2017, Hugo Geoffroy" | |
__credits__ = ['EDEN SAS'] | |
__license__ = "GPLv3" | |
__version__ = "0.0.1" | |
__maintainer__ = "pistache" | |
__email__ = "[email protected]" | |
__status__ = "Development" | |
# 1. Visitors and mergers | |
# ============================================================================= | |
# | |
# 1.1. Visitor implementations for each merge strategy | |
# ----------------------------------------------------------------------------- | |
def _visit_mapped(sources): | |
"""Merge same-keyed values in mapped sources. """ | |
mapped = {} | |
for source in sources: | |
for key, value in source.items(): | |
mapped.setdefault(key, []).append((value)) | |
return mapped.items(), True, True | |
def _visit_zipped(sources): | |
"""Merge same-index values in zipped sources. """ | |
counted = [] | |
for source in sources: | |
while len(source) > len(counted): | |
counted.append([]) | |
for index, value in enumerate(source): | |
counted[index].append(value) | |
return counted, True, False | |
def _visit_nested(sources): | |
"""Chain sources, recursing to rebuild them. """ | |
nested = [] | |
for source in sources: | |
for value in source: | |
nested.append([value]) | |
return nested, True, False | |
def _visit_chained(sources): | |
"""Chain sources, recursing to rebuild them. """ | |
return chain(*sources), False, False | |
def _visit_default(sources): | |
"""Return last source value. """ | |
return sources[-1], False, False | |
# 1.2. Visitor mapping and dispatch function | |
# ----------------------------------------------------------------------------- | |
# Visitor used for any type that has no entry in VISITOR_MAPPING.
DEFAULT_VISITOR = _visit_default

# Dispatch table: exact type of the first merged value -> visitor.
VISITOR_MAPPING = {
    dict: _visit_mapped,    # group values by key
    tuple: _visit_zipped,   # group values by index
    list: _visit_nested,    # chain elements, rebuilding each one
    set: _visit_chained,    # chain elements as they are
}
# 2. Deep merge algorithm | |
# ============================================================================= | |
# 2.1. Recursive implementation | |
# ----------------------------------------------------------------------------- | |
def deepmerge_recursive(source, *sources, **kwargs):
    """Deep merge algorithm (recursive form).

    The exact type of ``source`` selects a visitor from ``visitor_mapping``
    (falling back to ``default_visitor``). The visitor returns a
    ``(value, nested, mapped)`` triple; when ``nested`` is true every item
    of ``value`` is a group of values that is merged by a recursive call,
    and when ``mapped`` is true the items are ``(key, group)`` pairs.

    :param source: first value to merge; its type selects the visitor and
        the type used to rebuild the result.
    :param sources: further values to merge with ``source``, in order.
    :param default_visitor: (keyword) visitor for types without an entry in
        ``visitor_mapping``; ``None`` or omitted means ``DEFAULT_VISITOR``.
    :param visitor_mapping: (keyword) mapping of type to visitor; ``None``
        or omitted means ``VISITOR_MAPPING``.
    :returns: the merged value. For a type found in ``visitor_mapping`` (or
        any nested result) it is rebuilt with ``type(source)``; a non-nested
        result of the default visitor is returned unchanged.
    :raises TypeError: if any other keyword argument is given.
    """
    # read keyword arguments (kept as **kwargs, as in the original signature)
    default_visitor = kwargs.pop('default_visitor', None)
    visitor_mapping = kwargs.pop('visitor_mapping', None)
    if kwargs:
        raise TypeError(
            "deepmerge_recursive() got unexpected keyword arguments {}"
            "".format(', '.join(map(repr, kwargs.keys())))
        )
    if default_visitor is None:
        default_visitor = DEFAULT_VISITOR
    if visitor_mapping is None:
        visitor_mapping = VISITOR_MAPPING

    # pick the visitor from the type of the first value and call it
    stype = type(source)
    visitor = visitor_mapping.get(stype)
    known = visitor is not None
    if not known:
        visitor = default_visitor
    value, nested, mapped = visitor([source] + list(sources))

    if nested:
        # forward the visitors so custom ones also apply below the top level
        options = {
            'default_visitor': default_visitor,
            'visitor_mapping': visitor_mapping,
        }
        if mapped:
            value = [(key, deepmerge_recursive(*group, **options))
                     for key, group in value]
        else:
            value = [deepmerge_recursive(*group, **options)
                     for group in value]
    elif not known:
        # The default visitor already produced the final value. Passing it
        # through type(source) would coerce it, or crash (NoneType(None)).
        return value

    # build result
    return stype(value)
# 2.2. Iterative implementation | |
# ----------------------------------------------------------------------------- | |
def deepmerge_iterative(source, *sources, **kwargs):
    """Deep merge algorithm (iterative form).

    Same contract as the recursive form, driven by an explicit stack
    instead of recursive calls. The exact type of the first value of a group
    selects a visitor from ``visitor_mapping`` (falling back to
    ``default_visitor``); the visitor returns ``(value, nested, mapped)``.

    :param source: first value to merge; its type selects the visitor and
        the type used to rebuild the result.
    :param sources: further values to merge with ``source``, in order.
    :param default_visitor: (keyword) visitor for types without an entry in
        ``visitor_mapping``; ``None`` or omitted means ``DEFAULT_VISITOR``.
    :param visitor_mapping: (keyword) mapping of type to visitor; ``None``
        or omitted means ``VISITOR_MAPPING``.
    :returns: the merged value. For a type found in ``visitor_mapping`` (or
        any nested result) it is rebuilt with the type of the first value; a
        non-nested result of the default visitor is returned unchanged.
    :raises TypeError: if any other keyword argument is given.
    """
    # read keyword arguments (kept as **kwargs, as in the original signature)
    default_visitor = kwargs.pop('default_visitor', None)
    visitor_mapping = kwargs.pop('visitor_mapping', None)
    if kwargs:
        raise TypeError(
            "deepmerge_iterative() got unexpected keyword arguments {}"
            "".format(', '.join(map(repr, kwargs.keys())))
        )
    if default_visitor is None:
        default_visitor = DEFAULT_VISITOR
    if visitor_mapping is None:
        visitor_mapping = VISITOR_MAPPING

    # Frame layout: (payload, parent accumulator, parent is mapped,
    # key in parent, rebuild type). A rebuild type of None marks a "visit"
    # frame whose payload is a group of values still to be merged; otherwise
    # the payload is the list of already merged children to rebuild.
    result = []
    stack = [([source] + list(sources), result, False, None, None)]
    while stack:
        payload, accum, mapped, key, stype = stack.pop()
        if stype is None:
            stype = type(payload[0])
            visitor = visitor_mapping.get(stype)
            known = visitor is not None
            if not known:
                visitor = default_visitor
            value, nested, child_mapped = visitor(payload)
            if nested:
                children = []
                # pushed first, so it is popped after all of its children
                stack.append((children, accum, mapped, key, stype))
                # Materialise before reversing: dict views (before Python
                # 3.8) and plain iterators are not reversible. Pushing in
                # reverse makes the children pop in their original order.
                for child in reversed(list(value)):
                    child_key, group = child if child_mapped else (None, child)
                    stack.append(
                        (group, children, child_mapped, child_key, None))
                continue
            if known:
                value = stype(value)
            # else: the default visitor already produced the final value;
            # passing it through stype would coerce it, or crash on None.
        else:
            value = stype(payload)
        accum.append((key, value) if mapped else value)
    return result[0]
# 3. Test suite | |
# ============================================================================= | |
# Each case pairs the positional arguments given to a deepmerge function
# with the value the merge is expected to return.
_TEST_CASES = [
    # scalars: the last value wins
    ([1, 2, 3, 4], 4),
    # lists are chained
    ([[1, 2], [2, 3], [3, 4]], [1, 2, 2, 3, 3, 4]),
    # tuples are merged index by index
    ([(1, 2), (2, 3), (3, 4)], (3, 4)),
    ([([1], ["a"]), ([2], ["b"]), ([3], ["c"])],
     ([1, 2, 3], ["a", "b", "c"])),
    # sets are chained
    ([{1, 2}, {2, 3}, {3, 4}], {1, 2, 3, 4}),
    # dictionaries are merged key by key
    ([{1: 2}, {2: 3}, {3: 4}], {1: 2, 2: 3, 3: 4}),
    ([{1: 1}, {1: 2}, {1: 3}], {1: 3}),
    ([{'foo': [1, 2, 3], 'spam': 'egg'},
      {'foo': [2, 3, 4], 'spam': 'fish'}],
     {'foo': [1, 2, 3, 2, 3, 4], 'spam': 'fish'}),
]
TEST_SOURCES, TEST_RESULTS = (list(column) for column in zip(*_TEST_CASES))
def test_deepmerge_recursive():
    """Check :func:`deepmerge_recursive` against every expected result. """
    # TEST_SOURCES[i] holds the arguments whose merge must be TEST_RESULTS[i]
    for arguments, expected in zip(TEST_SOURCES, TEST_RESULTS):
        assert deepmerge_recursive(*arguments) == expected
def test_deepmerge_iterative():
    """Check :func:`deepmerge_iterative` against every expected result. """
    # TEST_SOURCES[i] holds the arguments whose merge must be TEST_RESULTS[i]
    for arguments, expected in zip(TEST_SOURCES, TEST_RESULTS):
        assert deepmerge_iterative(*arguments) == expected
# 4. Entry point | |
# ============================================================================= | |
def main():
    """Run the test suite: recursive form first, then iterative form. """
    for check in (test_deepmerge_recursive, test_deepmerge_iterative):
        check()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment