Created
April 5, 2023 09:38
-
-
Save Hrissimir/2c3c81dafe2ee00d1a4586daaeebf70f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Encoders for 'any' type of data that is not supported by default (+tests).""" | |
import datetime | |
import decimal | |
import enum | |
import inspect | |
import ipaddress | |
import json | |
import pathlib | |
import sqlite3 | |
import textwrap | |
import types | |
import typing | |
import unittest | |
class AnyJSONEncoderBase(json.JSONEncoder):
    """Base JSONEncoder for 'any' type of data that is not supported by default.

    Keyword arguments are forwarded to :class:`json.JSONEncoder`; options the
    caller omits fall back to the defaults chosen in ``__init__``.  Subclasses
    change the output layout by overriding ``SEPARATORS`` (and ``indent``).
    """

    # (item_separator, key_separator) used when ``separators`` is not passed.
    SEPARATORS = (", ", ": ")

    def __init__(self, **kwargs):
        kwargs.setdefault("skipkeys", False)
        kwargs.setdefault("ensure_ascii", True)
        kwargs.setdefault("check_circular", True)
        kwargs.setdefault("allow_nan", True)
        kwargs.setdefault("sort_keys", False)
        kwargs.setdefault("separators", self.SEPARATORS)
        super().__init__(**kwargs)

    def default(self, o: typing.Any) -> typing.Any:
        """Return a JSON-serializable stand-in for ``o``.

        Called by the encoder for objects it cannot serialize natively, and
        recursively by this method for mapping values, ``__json__`` results
        and extracted instance attributes.  Checks run in order and the first
        match wins.

        NOTE: the recursive conversion is eager, so a cyclic object graph
        recurses until ``RecursionError`` instead of reaching the encoder's
        ``check_circular`` guard.

        Raises:
            TypeError: if ``o`` matches none of the supported conversions
                (raised by :meth:`json.JSONEncoder.default`).
        """
        # Already JSON-compatible: returned as-is.  This matters for the
        # recursive calls below; the encoder itself never passes these here.
        if isinstance(o, (bool, int, float, str, list, tuple, type(None))):
            return o
        if isinstance(o, (set, frozenset)):
            try:
                return tuple(sorted(o))
            except TypeError:
                # Members are not mutually comparable (e.g. {1, "a"}); order
                # them by repr so the output is still deterministic.
                return tuple(sorted(o, key=repr))
        if isinstance(o, type):
            return str(o)
        if isinstance(o, bytes):
            # Invalid UTF-8 stays visible as "\xNN" escapes instead of
            # raising UnicodeDecodeError.
            return o.decode("utf-8", errors="backslashreplace")
        if isinstance(o, complex):
            return o.real, o.imag
        if isinstance(o, decimal.Decimal):
            # NOTE: the float conversion may lose precision.
            return float(o)
        if isinstance(o, enum.Enum):
            return o.name
        # datetime.datetime subclasses datetime.date, so it must come first.
        if isinstance(o, datetime.datetime):
            if o.tzinfo is not None:
                o = o.astimezone(datetime.timezone.utc)
            # isoformat() always zero-pads the year to four digits, unlike
            # strftime("%Y") on some platforms.  Naive values are not
            # converted but still receive the "Z" suffix.
            return o.replace(tzinfo=None).isoformat(timespec="microseconds") + "Z"
        if isinstance(o, datetime.date):
            return o.isoformat()
        if isinstance(o, datetime.time):
            return o.strftime("%H:%M:%S.%f")
        if isinstance(o, datetime.timedelta):
            return o.total_seconds()
        # PurePath covers Path as well as the Pure{Posix,Windows}Path flavours.
        if isinstance(o, pathlib.PurePath):
            return str(o)
        if isinstance(o, sqlite3.Row):
            return tuple(o)
        if isinstance(o, sqlite3.Cursor):
            # Consumes the cursor's remaining rows.
            return list(o)
        # IPv4Interface/IPv6Interface subclass the address types.
        if isinstance(
            o,
            (
                ipaddress.IPv4Address,
                ipaddress.IPv6Address,
                ipaddress.IPv4Network,
                ipaddress.IPv6Network,
            ),
        ):
            return str(o)
        if isinstance(o, Exception):
            msg = getattr(o, "msg", str(o))
            args = getattr(o, "args", ())
            return f"<{type(o).__name__}(msg={msg!r}, args={args!r})>"
        if isinstance(o, typing.Mapping):
            result = {}
            for k, v in o.items():
                # JSON object keys must be str/int/float/bool/None; anything
                # else is replaced by its repr instead of failing the encode.
                if isinstance(k, (bool, int, float, str, type(None))):
                    key = k
                else:
                    key = repr(k)
                result[key] = self.default(v)
            return result
        if isinstance(o, (types.GeneratorType, typing.Iterator, typing.Generator)):
            # Consumes the iterator.
            return tuple(o)
        if hasattr(o, "__json__"):
            # Explicit hook: either a method or a plain attribute.
            attr = getattr(o, "__json__")
            value = attr() if callable(attr) else attr
            return self.default(value)
        if hasattr(o, "__getitem__") and hasattr(o, "keys"):
            # Dict-like but not a registered Mapping: go through the Mapping
            # branch so non-JSON keys are normalized the same way.
            return self.default(dict(o))
        if hasattr(o, "__dict__"):
            # Public data attributes (instance and class level).  Dunders,
            # callables, routines/method descriptors and generator objects
            # are skipped.  Every other inspect.is*() predicate the original
            # filter used implies callable(), so it is covered by that check.
            public = {
                name: member
                for name, member in inspect.getmembers(o)
                if not name.startswith("__")
                and not callable(member)
                and not inspect.isroutine(member)
                and not inspect.isgenerator(member)
            }
            return self.default(public)
        try:
            iterable = iter(o)
        except TypeError:
            pass
        else:
            return list(iterable)
        # Let the base class raise the standard TypeError.
        return super().default(o)
class AnyJSONEncoderCompact(AnyJSONEncoderBase):
    """JSON encoder for 'any' data that produces the most compact text.

    By default the output is a single line (no ``indent``) and the separators
    carry no padding whitespace.  Either option can still be overridden
    through keyword arguments.
    """

    SEPARATORS = (",", ":")

    def __init__(self, **kwargs):
        # Start from this class's defaults, then let caller options win.
        options = {"indent": None, "separators": self.SEPARATORS}
        options.update(kwargs)
        super().__init__(**options)
class AnyJSONEncoderPretty(AnyJSONEncoderBase):
    """JSON encoder for 'any' data that produces human-readable, indented text.

    Nested containers are indented by 4 spaces by default.  The item
    separator is a bare "," so indented lines end without trailing
    whitespace.  Both options can be overridden through keyword arguments.
    """

    SEPARATORS = (",", ": ")

    def __init__(self, **kwargs):
        # Caller-supplied options come last, so they override these defaults.
        super().__init__(**{"indent": 4, "separators": self.SEPARATORS, **kwargs})
class AnyJSONEncodersTests(unittest.TestCase):
    """Unit-tests for AnyJSONEncoderBase and its subclasses."""

    # One shared instance; ``default`` keeps no per-call state on it.
    encoder = AnyJSONEncoderBase()

    def test_default_handles_bool(self):
        self.assertIs(True, self.encoder.default(True))
        self.assertIs(False, self.encoder.default(False))

    def test_default_handles_int(self):
        value = 1
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_float(self):
        value = 1.00
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_str(self):
        value = "1.00"
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_list(self):
        value = [1, 2]
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_tuple(self):
        value = (1, 2)
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_none(self):
        value = None
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_set(self):
        value = {1, 2, 2, 1, 1, 3}
        expected = (1, 2, 3)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertTupleEqual(expected, actual)

    def test_default_handles_frozenset(self):
        value = frozenset([1, 2, 2, 1, 1, 3])
        expected = (1, 2, 3)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertTupleEqual(expected, actual)

    def test_default_handles_type(self):
        expected = str(object)
        actual = self.encoder.default(object)
        self.assertEqual(expected, actual)

    def test_default_handles_bytes(self):
        expected = "value"
        actual = self.encoder.default(expected.encode())
        self.assertEqual(expected, actual)

    def test_default_handles_complex(self):
        value = complex(1.2, 3.4)
        expected = (1.2, 3.4)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertTupleEqual(expected, actual)

    def test_default_handles_decimal(self):
        value = decimal.Decimal("0.1")
        expected = 0.1
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_enum(self):
        class TestEnum(enum.Enum):
            VALUE_1 = enum.auto()
            VALUE_2 = enum.auto()

        value = TestEnum.VALUE_2
        expected = "VALUE_2"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_datetime_naive(self):
        value = datetime.datetime(2001, 12, 31, 21, 44)
        expected = "2001-12-31T21:44:00.000000Z"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_datetime_aware(self):
        value = datetime.datetime(2001, 12, 31, 21, 44, tzinfo=datetime.timezone.utc)
        expected = "2001-12-31T21:44:00.000000Z"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_datetime_aware_non_utc(self):
        # Aware values are converted to UTC before formatting.
        tz = datetime.timezone(datetime.timedelta(hours=2))
        value = datetime.datetime(2001, 12, 31, 23, 44, tzinfo=tz)
        expected = "2001-12-31T21:44:00.000000Z"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_date(self):
        value = datetime.date(2001, 12, 31)
        expected = "2001-12-31"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_time(self):
        value = datetime.time(21, 44)
        expected = "21:44:00.000000"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_timedelta(self):
        value = datetime.timedelta(milliseconds=500)
        expected = 0.5
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_path(self):
        value = pathlib.Path(r"C:\Windows")
        expected = r"C:\Windows"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_sqlite_row(self):
        connection = sqlite3.connect(":memory:")
        self.addCleanup(connection.close)
        connection.row_factory = sqlite3.Row
        row = connection.execute("SELECT 1 AS a, 'x' AS b").fetchone()
        actual = self.encoder.default(row)
        self.assertIsInstance(actual, tuple)
        self.assertTupleEqual((1, "x"), actual)

    def test_default_handles_sqlite_cursor(self):
        connection = sqlite3.connect(":memory:")
        self.addCleanup(connection.close)
        cursor = connection.execute("SELECT 1 UNION ALL SELECT 2")
        actual = self.encoder.default(cursor)
        self.assertIsInstance(actual, list)
        self.assertListEqual([(1,), (2,)], actual)

    def test_default_handles_ip_address_v4(self):
        value = ipaddress.IPv4Address("127.0.0.1")
        expected = "127.0.0.1"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_ip_address_v6(self):
        value = ipaddress.IPv6Address("2001:db8:3333:4444:5555:6666:7777:8888")
        expected = "2001:db8:3333:4444:5555:6666:7777:8888"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_exception(self):
        value = Exception("wow")
        expected = "<Exception(msg='wow', args=('wow',))>"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_mapping_with_bad_keys(self):
        value = {object: 1}
        expected = {"<class 'object'>": 1}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)

    def test_default_handles_generators(self):
        def gen_func():
            for i in range(3):
                yield i

        value = gen_func()
        expected = (0, 1, 2)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertEqual(expected, actual)

    def test_default_handles_json_dunder_attr(self):
        class MyCls:
            def __init__(self, a):
                self.a = a
                self.__json__ = {"a": self.a}

        value = MyCls(42)
        expected = {"a": 42}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)
        expected_json = '{"a": 42}'
        actual_json = self.encoder.encode(value)
        self.assertEqual(expected_json, actual_json)

    def test_default_handles_json_dunder_func(self):
        class MyCls:
            def __init__(self, a):
                self.a = a

            def __json__(self):
                return {"a": self.a}

        value = MyCls(42)
        expected = {"a": 42}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)
        expected_json = '{"a": 42}'
        actual_json = self.encoder.encode(value)
        self.assertEqual(expected_json, actual_json)

    def test_default_handles_dict_like(self):
        class DictLike:
            def __init__(self, data: dict):
                self.data = data

            def keys(self):
                return self.data.keys()

            def __getitem__(self, item):
                return self.data[item]

        value = DictLike({"a": 1, "b": 2})
        expected = {"a": 1, "b": 2}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)

    def test_default_handles_custom_classes(self):
        class Person:
            def __init__(self, name: str, age: int):
                self.name = name
                self.age = age

        value = Person("Santa", 42)
        expected = {"age": 42, "name": "Santa"}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)

    def test_default_handles_iterators(self):
        items = [1, 2, 3]
        value = iter(items)
        expected = (1, 2, 3)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertEqual(expected, actual)

    def test_default_raises_for_unsupported(self):
        # A bare object() has no __dict__ and is not iterable, so the base
        # JSONEncoder.default raises.
        with self.assertRaises(TypeError):
            self.encoder.default(object())

    def test_encode_base(self):
        expected = '[{"key": "value"}]'
        actual = AnyJSONEncoderBase().encode([dict(key="value")])
        self.assertEqual(expected, actual)

    def test_encode_compact(self):
        expected = '[{"key":"value"}]'
        actual = AnyJSONEncoderCompact().encode([dict(key="value")])
        self.assertEqual(expected, actual)

    def test_encode_pretty(self):
        expected = textwrap.dedent(
            """\
            [
                {
                    "key": "value"
                }
            ]"""
        )
        actual = AnyJSONEncoderPretty().encode([dict(key="value")])
        self.assertEqual(expected, actual)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment