Last active
September 30, 2022 10:34
-
-
Save wonderbeyond/7982ed0894ca0c11dc77d946e6c5cf99 to your computer and use it in GitHub Desktop.
[Python][deep data] Handle (get/set) nested python dicts and lists.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Union, Any, List, Dict, Iterable | |
import functools | |
def _get(data, key: Union[str, int], default=None) -> Any: | |
""" | |
Get value from dict or list by key (integer or string). | |
""" | |
if isinstance(data, dict): | |
return data.get(key, default) | |
elif isinstance(data, list): | |
try: | |
key = int(key) | |
except ValueError: | |
return default | |
try: | |
return data[key] | |
except IndexError: | |
return default | |
else: | |
return default | |
def _glob_get(data, key: Union[str, int], default=None) -> Dict[Union[str, int], Any]:
    """
    Expand a single path segment, which may be the wildcard ``'*'``.

    Returns a mapping from each concrete key the segment resolves to onto
    the value ``_get`` finds there:

    * a non-wildcard ``key`` always yields exactly ``{key: value}``, where
      ``value`` is whatever ``_get`` returns (possibly ``default``);
    * ``'*'`` on a dict yields one entry per dict key;
    * ``'*'`` on a list yields one entry per integer index;
    * ``'*'`` on anything else yields an empty mapping.
    """
    if key == '*':
        if isinstance(data, dict):
            candidates: Iterable[Union[str, int]] = data.keys()
        elif isinstance(data, list):
            candidates = range(len(data))
        else:
            candidates = ()
    else:
        candidates = (key,)
    return {candidate: _get(data, candidate, default=default) for candidate in candidates}
def deep_get(data, keys, default=None) -> Any:
    """
    Get a value from nested dicts or lists by a dotted path.

    ``keys`` is split on ``"."`` and each segment is resolved one level at a
    time with ``_get``. Dict levels use the segment as a key; list levels
    parse it as an integer index (e.g. ``"contact.sns.1.id"``).

    Returns ``default`` as soon as a segment cannot be resolved. That covers
    a missing key, a bad or out-of-range index, or a non-container value in
    the way. The default is returned as-is and is never descended into.
    """
    # A private sentinel separates "not found" from real stored values,
    # including None. Passing ``default`` straight to ``_get`` at every step
    # (as a reduce over the segments does) lets later segments be looked up
    # *inside* a dict/list default: deep_get({}, 'a.b', {'b': 1}) returned 1.
    missing = object()
    node = data
    for segment in keys.split("."):
        node = _get(node, segment, default=missing)
        if node is missing:
            return default
    return node
def glob_deep_get(data, keys, default=None) -> Dict[str, Any]:
    """
    Works like ``deep_get``, but allows ``'*'`` wildcards in ``keys``.

    Returns a flat dict mapping each fully expanded dotted path to the value
    found there. A ``'*'`` segment fans out over every key of a dict or every
    index of a list. It matches nothing on any other value, so such a branch
    contributes no entries.

    When a non-wildcard segment cannot be resolved, the remaining segments
    are still appended to the reported path, and that path maps to
    ``default``. The exception is when a later segment is ``'*'``: that
    branch then yields nothing. ``default`` itself is never descended into.
    """
    # Traverse with a private sentinel as the "not found" value. It is
    # neither a dict nor a list, so later segments treat it as a leaf. It is
    # swapped for ``default`` only when an entry is recorded. Threading
    # ``default`` itself through the traversal would let later segments
    # index into (or wildcard-expand) a dict/list default.
    missing = object()
    ret: Dict[str, Any] = {}

    def match_and_push(path: List[Union[str, int]], node, segments: List[str]):
        is_last = len(segments) == 1
        for concrete_key, value in _glob_get(node, segments[0], default=missing).items():
            full_path = [*path, concrete_key]
            if is_last:
                ret['.'.join(str(p) for p in full_path)] = default if value is missing else value
            else:
                match_and_push(full_path, value, segments[1:])

    match_and_push([], data, keys.split('.'))
    return ret
def deep_set(d, key, value):
    """
    Set ``value`` at the dotted path ``key`` inside nested dicts/lists, in place.

    Missing intermediate levels are created on the way down:

    * in a dict, an absent segment is created as an empty dict;
    * in a list, a non-negative index at or past the end grows the list.
      Any gap is filled with ``None``, and the new empty dict (or ``value``,
      for the last segment) is placed exactly at that index, so the same
      path reads it back.

    Existing values along the path are reused as-is.

    Raises:
        ValueError: a segment addressing a list is not an integer.
        IndexError: a negative list index is out of range.
        TypeError: the path runs through a value that is neither a dict nor
            a list, or the final container does not support item assignment.
    """
    *parents, last = key.split('.')
    node = d
    for segment in parents:
        if isinstance(node, dict):
            node = node.setdefault(segment, {})
        elif isinstance(node, list):
            index = int(segment)
            if index >= len(node):
                # Pad first so the new child lands at ``index`` itself.
                # A bare append would put it at ``len(node)``, which is a
                # different path from the one requested.
                node.extend([None] * (index - len(node)))
                node.append({})
            node = node[index]
        else:
            # Silently skipping this segment would apply the remaining keys
            # one level too high.
            raise TypeError(
                f"cannot descend into {type(node).__name__!r} at segment "
                f"{segment!r} of path {key!r}"
            )

    if isinstance(node, list):
        index = int(last)
        if index >= len(node):
            node.extend([None] * (index - len(node)))
            node.append(value)
        else:
            node[index] = value
    else:
        node[last] = value
if __name__ == '__main__':
    test_data = {
        'name': 'Wonder',
        'contact': {
            'addresses': ['Moon', 'Mars'],
            'phone': 110,
            'sns': [
                {'platform': 'twitter', 'id': '@wonder'},
                {'platform': 'weibo', 'id': '@workwonder'},
            ],
        },
    }

    # Plain dotted lookups: hits, misses at various depths, and list indexes.
    expected_lookups = {
        'name': 'Wonder',
        'english_name': None,
        'contact.phone': 110,
        'contact.phone.mobile': None,
        'other1.other2.other3': None,
        'contact.sns.1.id': '@workwonder',
        'contact.addresses.0': 'Moon',
        'contact.addresses.1': 'Mars',
        'contact.addresses.2': None,
    }
    for path, expected in expected_lookups.items():
        assert deep_get(test_data, path) == expected

    # Non-container roots have nothing to look up, so the default comes back.
    for root in (None, 100, 'some'):
        assert deep_get(root, 'name') is None
    assert deep_get('some', 'name', 'Wonder') == 'Wonder'

    # Wildcard lookups fan out over list indexes.
    assert glob_deep_get(test_data, 'contact.sns.*.platform') == {
        'contact.sns.0.platform': 'twitter',
        'contact.sns.1.platform': 'weibo',
    }
    assert glob_deep_get(test_data, 'contact.addresses.*') == {
        'contact.addresses.0': 'Moon',
        'contact.addresses.1': 'Mars',
    }

    # Writes replace existing values, both in dicts and inside list items.
    deep_set(test_data, 'name', 'banned')
    assert deep_get(test_data, 'name') == 'banned'
    deep_set(test_data, 'contact.sns.1.id', 'forbidden')
    assert deep_get(test_data, 'contact.sns.1.id') == 'forbidden'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment