Created
August 29, 2020 12:34
-
-
Save hemebond/7ec81a9d202437c0fb7919be389f892e to your computer and use it in GitHub Desktop.
A custom YAML constructor that does a deep merge of dicts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import copy, deepcopy

import yaml
from yaml.constructor import ConstructorError, SafeConstructor
from yaml.loader import Loader
from yaml.nodes import MappingNode
# Copyright Ferry Boender, released under the MIT license.
def _deep_merge_into(target, src):
    """Merge *src* into *target* in place; *target* must be safe to mutate.

    Values taken from *src* are deep-copied so the result never shares
    mutable state with *src*.
    """
    for k, v in src.items():
        if k not in target:
            target[k] = deepcopy(v)
            continue
        current = target[k]
        if isinstance(v, list) and isinstance(current, list):
            current.extend(deepcopy(v))
        elif isinstance(v, dict) and isinstance(current, dict):
            # Recurse in place: `current` is already owned by `target`, so the
            # merged result does not need to be re-assigned or re-copied.
            _deep_merge_into(current, v)
        elif isinstance(v, set) and isinstance(current, set):
            current.update(v)
        else:
            # Scalars, or container types that don't match: src wins.
            target[k] = deepcopy(v)


def deepupdate(tgt, src):
    """Return a deep copy of *tgt* deep-updated with the contents of *src*.

    Neither argument is modified. For each ``k, v`` in *src*:

    * if ``k`` is not in the target, a deep copy of ``v`` is stored;
    * if ``v`` and the existing value are both lists, the list is extended
      with (deep copies of) the items of ``v``;
    * if both are sets, the set is updated with the items of ``v``;
    * if both are dicts, they are merged recursively by these same rules;
    * otherwise the existing value is replaced by a deep copy of ``v``.

    Examples:
    >>> t = {'name': 'Ferry', 'hobbies': ['programming', 'sci-fi']}
    >>> deepupdate(t, {'hobbies': ['gaming']})
    {'name': 'Ferry', 'hobbies': ['programming', 'sci-fi', 'gaming']}
    >>> t
    {'name': 'Ferry', 'hobbies': ['programming', 'sci-fi']}
    >>> deepupdate({'a': {'b': 1}}, {'a': {'c': 2}})
    {'a': {'b': 1, 'c': 2}}
    """
    # Copy once at the top level; the helper then mutates only this copy.
    target = deepcopy(tgt)
    _deep_merge_into(target, src)
    return target
class Constructor(SafeConstructor):
    """
    Customise the mapping constructor to do a deep merge instead
    of the regular shallow merge.

    When the same key occurs more than once in a mapping (for example once
    via a ``<<`` merge key and once explicitly) and both values are dicts,
    the later dict is deep-merged into the earlier one with ``deepupdate``
    instead of replacing it. For any other combination the later value wins.
    """

    def construct_mapping(self, node, deep=False):
        """Construct a dict from *node*, deep-merging repeated dict-valued keys.

        :param node: the YAML node to construct; must be a ``MappingNode``.
        :param deep: forwarded to ``construct_object`` for keys. Values are
            always constructed with ``deep=True``.
        :returns: the constructed ``dict``.
        :raises ConstructorError: if *node* is not a mapping node, or if a
            key is unhashable.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(None, None,
                    "expected a mapping node, but found %s" % node.id,
                    node.start_mark)
        # Inherited from SafeConstructor: expands `<<` merge keys into plain
        # key/value pairs on the node, so merged-in keys may repeat explicit ones.
        self.flatten_mapping(node)
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            try:
                hash(key)
            except TypeError as exc:
                raise ConstructorError("while constructing a mapping", node.start_mark,
                        "found unacceptable key (%s)" % exc,
                        key_node.start_mark) from exc
            # Build values eagerly: deepupdate() copies its inputs, so any
            # nested container it sees must already be fully populated.
            value = self.construct_object(value_node, deep=True)
            # mapping.get() yields None for a new key, which is never a dict.
            if isinstance(value, dict) and isinstance(mapping.get(key), dict):
                mapping[key] = deepupdate(mapping[key], value)
            else:
                mapping[key] = value
        return mapping
class CustomLoader(Loader, Constructor):
    """YAML loader that deep-merges dict values of repeated mapping keys.

    Mixes the custom ``Constructor`` (with its deep-merging
    ``construct_mapping``) into yaml's ``Loader`` pipeline.

    NOTE(review): ``yaml.loader.Loader`` is PyYAML's full, unsafe loader. It
    can build arbitrary Python objects from ``!!python/...`` tags, so do not
    use this loader on untrusted input. Consider ``yaml.loader.SafeLoader`` as
    the first base class instead.
    """
if __name__ == "__main__":
    # NOTE(security): CustomLoader derives from yaml's full (unsafe) Loader,
    # which can construct arbitrary Python objects -- only load trusted files.
    # Binary mode lets PyYAML detect the encoding (UTF-8 / UTF-16 BOM) itself
    # rather than relying on the platform locale; `with` closes the handle.
    with open('sample.yaml', 'rb') as stream:
        print(yaml.load(stream, Loader=CustomLoader))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment