# Scraped from a GitHub Gist page ("Instantly share code, notes, and snippets."):
# gist pstch/4a625d4d8e4de4c674ab75125684f4e4 by @pstch,
# created August 13, 2017 05:18 -- "Python deepmerge algorithm".
#!/usr/bin/env python
"""Deep merge algorithm -- Merges multiple values, combining dictionaries by
their keys, tuples by their indices, chaining lists and sets, and
replacing by the last merged value for any other type.
"""
# 0. Imports & metadata
# =============================================================================
from itertools import chain
__author__ = 'Hugo Geoffroy'
__copyright__ = "Copyright 2017, Hugo Geoffroy"
__credits__ = ['EDEN SAS']
__license__ = "GPLv3"
__version__ = "0.0.1"
__maintainer__ = "pistache"
__email__ = "[email protected]"
__status__ = "Development"
# 1. Visitors and mergers
# =============================================================================
#
# 1.1. Visitor implementations for each merge strategy
# -----------------------------------------------------------------------------
def _visit_mapped(sources):
"""Merge same-keyed values in mapped sources. """
mapped = {}
for source in sources:
for key, value in source.items():
mapped.setdefault(key, []).append((value))
return mapped.items(), True, True
def _visit_zipped(sources):
"""Merge same-index values in zipped sources. """
counted = []
for source in sources:
while len(source) > len(counted):
counted.append([])
for index, value in enumerate(source):
counted[index].append(value)
return counted, True, False
def _visit_nested(sources):
"""Chain sources, recursing to rebuild them. """
nested = []
for source in sources:
for value in source:
nested.append([value])
return nested, True, False
def _visit_chained(sources):
"""Chain sources, recursing to rebuild them. """
return chain(*sources), False, False
def _visit_default(sources):
"""Return last source value. """
return sources[-1], False, False
# 1.2. Visitor mapping and dispatch function
# -----------------------------------------------------------------------------
# Visitor for any type with no entry in VISITOR_MAPPING: the last merged
# value replaces all the others.
DEFAULT_VISITOR = _visit_default
# Dispatch table keyed by the exact type of the first merged value (the
# merge functions look it up with ``type(value)``, so subclasses of these
# types fall back to DEFAULT_VISITOR).
VISITOR_MAPPING = dict([
    (dict, _visit_mapped),    # merge by key, recursing per key
    (tuple, _visit_zipped),   # merge by index, recursing per index
    (list, _visit_nested),    # concatenate, recursing per element
    (set, _visit_chained),    # chain then rebuild as a set (union)
])
# 2. Deep merge algorithm
# =============================================================================
# 2.1. Recursive implementation
# -----------------------------------------------------------------------------
def deepmerge_recursive(source, *sources, **kwargs):
    """Deep merge algorithm (recursive form).

    Merge *source* with every value in *sources*, in order. The exact type
    of the first value selects a visitor from *visitor_mapping* (falling
    back to *default_visitor*). A visitor returns ``(value, nested,
    mapped)``: when *nested* is true, each group in *value* is merged
    recursively (as ``(key, group)`` pairs when *mapped* is true). The
    result is then rebuilt with the first value's type.

    Keyword arguments (both apply at every nesting level):
        default_visitor: visitor for types missing from *visitor_mapping*
            (defaults to ``DEFAULT_VISITOR``).
        visitor_mapping: dict mapping a type to its visitor (defaults to
            ``VISITOR_MAPPING``).

    Raises:
        TypeError: if any other keyword argument is given.
    """
    # read keyword arguments
    default_visitor = kwargs.pop('default_visitor', DEFAULT_VISITOR)
    visitor_mapping = kwargs.pop('visitor_mapping', VISITOR_MAPPING)
    if kwargs:
        raise TypeError(
            "deepmerge_recursive() got unexpected keyword arguments {}"
            "".format(', '.join(map(repr, kwargs.keys())))
        )
    # visit source values and call visitor
    sources = list(chain([source], sources))
    stype = type(sources[0])
    visitor = visitor_mapping.get(stype, default_visitor)
    value, nested, mapped = visitor(sources)
    # recurse if nested
    if nested:
        merged = []
        for group in value:
            key, group = group if mapped else (None, group)
            # Forward the visitor configuration so custom visitors also
            # apply to nested values, not only to the top level.
            item = deepmerge_recursive(
                *group,
                default_visitor=default_visitor,
                visitor_mapping=visitor_mapping
            )
            merged.append((key, item) if mapped else item)
        value = merged
    # build result: a value that already has exactly the target type (e.g.
    # the last value kept by the default visitor) is returned as-is, because
    # re-calling the type on it fails for types without a copy constructor
    # such as ``type(None)``.
    if type(value) is stype:
        return value
    return stype(value)
# 2.2. Iterative implementation
# -----------------------------------------------------------------------------
def deepmerge_iterative(source, *sources, **kwargs):
    """Deep merge algorithm (iterative form).

    Computes the same merge as the recursive form, using an explicit stack
    instead of recursion. Each stack frame is a tuple
    ``(sources, accum, mapped, key, stype)``:

    * ``stype is None``: a "visit" frame. The list *sources* is passed to
      the visitor selected by the exact type of ``sources[0]``.
    * otherwise: a "build" frame, pushed below the frames of its children.
      Its *sources* is the list the children append their merged values
      to, and it is converted to *stype* once all children have run.

    Either way the finished value is appended to *accum*, as
    ``(key, value)`` when *mapped* is true.

    Keyword arguments:
        default_visitor: visitor for types missing from *visitor_mapping*
            (defaults to ``DEFAULT_VISITOR``).
        visitor_mapping: dict mapping a type to its visitor (defaults to
            ``VISITOR_MAPPING``).

    Raises:
        TypeError: if any other keyword argument is given.
    """
    # read keyword arguments
    default_visitor = kwargs.pop('default_visitor', DEFAULT_VISITOR)
    visitor_mapping = kwargs.pop('visitor_mapping', VISITOR_MAPPING)
    if kwargs:
        raise TypeError(
            "deepmerge_iterative() got unexpected keyword arguments {}"
            "".format(', '.join(map(repr, kwargs.keys())))
        )
    # create stack and initial frame
    result = []
    stack = [(list(chain([source], sources)), result, False, None, None)]
    while stack:
        frame_sources, accum, mapped, key, stype = stack.pop()
        if stype is None:
            stype = type(frame_sources[0])
            visitor = visitor_mapping.get(stype, default_visitor)
            value, child_nested, child_mapped = visitor(frame_sources)
            if child_nested:
                children = []
                # Build frame: popped only after every child frame above it.
                stack.append((children, accum, mapped, key, stype))
                # Materialize before reversing: visitors may return any
                # iterable (a dict items view is not reversible before
                # Python 3.8). Children are pushed in reverse so that they
                # pop, and append to *children*, in their original order.
                for group in reversed(list(value)):
                    child_key, group = (
                        group if child_mapped else (None, group))
                    stack.append(
                        (group, children, child_mapped, child_key, None))
                continue
        else:
            value = frame_sources
        # A value that already has exactly the target type (e.g. the last
        # value kept by the default visitor) is used as-is: re-calling the
        # type fails for types without a copy constructor such as
        # ``type(None)``.
        if type(value) is not stype:
            value = stype(value)
        accum.append((key, value) if mapped else value)
    return result[0]
# 3. Test suite
# =============================================================================
# Each entry of TEST_SOURCES is a list of values passed positionally to the
# deep merge functions; TEST_RESULTS holds the expected result at the same
# index (the test functions pair the two lists with zip()).
TEST_SOURCES = [
    # scalars: the last value wins
    [1, 2, 3, 4],
    # lists: concatenated
    [[1, 2], [2, 3], [3, 4]],
    # tuples: merged per index, last value wins at each index
    [(1, 2), (2, 3), (3, 4)],
    # tuples of lists: merged per index, lists concatenated
    [([1], ["a"]), ([2], ["b"]), ([3], ["c"])],
    # sets: union
    [{1, 2}, {2, 3}, {3, 4}],
    # dicts with distinct keys: every key kept
    [{1: 2}, {2: 3}, {3: 4}],
    # dicts sharing one key: last value wins
    [{1: 1}, {1: 2}, {1: 3}],
    # dicts mixing a list value (concatenated) and a string (last wins)
    [{'foo': [1, 2, 3], 'spam': 'egg'},
     {'foo': [2, 3, 4], 'spam': 'fish'}],
]
# Expected merge result for the TEST_SOURCES entry at the same index.
TEST_RESULTS = [
    4,
    [1, 2, 2, 3, 3, 4],
    (3, 4),
    ([1, 2, 3], ["a", "b", "c"]),
    {1, 2, 3, 4},
    {1: 2, 2: 3, 3: 4},
    {1: 3},
    {'foo': [1, 2, 3, 2, 3, 4], 'spam': 'fish'}
]
def test_deepmerge_recursive():
    """Check each TEST_SOURCES case against TEST_RESULTS (recursive form). """
    cases = zip(TEST_SOURCES, TEST_RESULTS)
    for case_sources, expected in cases:
        assert deepmerge_recursive(*case_sources) == expected
def test_deepmerge_iterative():
    """Check each TEST_SOURCES case against TEST_RESULTS (iterative form). """
    cases = zip(TEST_SOURCES, TEST_RESULTS)
    for case_sources, expected in cases:
        assert deepmerge_iterative(*case_sources) == expected
# 4. Entry point
# =============================================================================
def main():
    """Run the recursive, then the iterative, deep merge tests. """
    for check in (test_deepmerge_recursive, test_deepmerge_iterative):
        check()


if __name__ == '__main__':
    main()
# (GitHub gist page footer: "Sign up for free to join this conversation on
# GitHub. Already have an account? Sign in to comment")