Skip to content

Instantly share code, notes, and snippets.

@Hrissimir
Created April 5, 2023 09:38
Show Gist options
  • Save Hrissimir/2c3c81dafe2ee00d1a4586daaeebf70f to your computer and use it in GitHub Desktop.
Save Hrissimir/2c3c81dafe2ee00d1a4586daaeebf70f to your computer and use it in GitHub Desktop.
"""Encoders for 'any' type of data that is not supported by default (+tests)."""
import datetime
import decimal
import enum
import inspect
import ipaddress
import json
import pathlib
import sqlite3
import textwrap
import types
import typing
import unittest
class AnyJSONEncoderBase(json.JSONEncoder):
    """Base JSONEncoder for 'any' type of data that is not supported by default.

    ``default()`` converts objects the standard encoder rejects (sets, types,
    bytes, complex numbers, decimals, enums, dates/times, paths, sqlite rows
    and cursors, IP addresses, exceptions, mappings, iterators, objects that
    expose ``__json__``, dict-like objects and plain custom objects) into
    JSON-encodable values. Anything else still raises ``TypeError``.
    """

    # Separators used when the caller does not pass ``separators``;
    # subclasses override this to change the output layout.
    SEPARATORS = (", ", ": ")

    def __init__(self, **kwargs):
        """Initialize the encoder.

        Accepts every ``json.JSONEncoder`` keyword argument; options the
        caller omits are filled in with the defaults below.
        """
        kwargs.setdefault("skipkeys", False)
        kwargs.setdefault("ensure_ascii", True)
        kwargs.setdefault("check_circular", True)
        kwargs.setdefault("allow_nan", True)
        kwargs.setdefault("sort_keys", False)
        kwargs.setdefault("separators", self.SEPARATORS)
        super().__init__(**kwargs)

    def default(self, o: typing.Any) -> typing.Any:
        """Return a JSON-encodable stand-in for ``o``.

        ``json.JSONEncoder`` calls this for objects it cannot serialize
        natively; it may also be called directly. Branches are tried in
        order, so subclasses must be checked before their base types.

        Args:
            o: the object to convert.

        Returns:
            A value the encoder can serialize. Lists/tuples in the result may
            still hold unsupported objects; the encoder passes those back to
            ``default`` as it reaches them.

        Raises:
            TypeError: if ``o`` matches none of the supported conversions
                (raised by ``json.JSONEncoder.default``).
        """
        # JSON-native values are returned unchanged; this matters when
        # ``default`` is called directly or recursively on mapping values.
        if isinstance(o, (bool, int, float, str, list, tuple, type(None))):
            return o
        if isinstance(o, (set, frozenset)):
            try:
                return tuple(sorted(o))
            except TypeError:
                # Mixed / unorderable elements (e.g. {1, "a"}): sort by repr
                # instead of failing, so the output is still deterministic.
                return tuple(sorted(o, key=repr))
        if isinstance(o, type):
            return str(o)
        if isinstance(o, bytes):
            # Invalid UTF-8 is rendered as "\xNN" escapes instead of raising
            # UnicodeDecodeError in the middle of encoding.
            return o.decode("utf-8", errors="backslashreplace")
        if isinstance(o, complex):
            return o.real, o.imag
        if isinstance(o, decimal.Decimal):
            return float(o)
        if isinstance(o, enum.Enum):
            return o.name
        if isinstance(o, datetime.datetime):
            # Aware values are converted to UTC. NOTE: naive values are
            # formatted unchanged but still receive the "Z" (UTC) suffix.
            if o.tzinfo is not None:
                o = o.astimezone(datetime.timezone.utc)
            return o.strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"
        # Must come after the datetime branch: datetime subclasses date.
        if isinstance(o, datetime.date):
            return o.isoformat()
        if isinstance(o, datetime.time):
            return o.strftime("%H:%M:%S.%f")
        if isinstance(o, datetime.timedelta):
            return o.total_seconds()
        if isinstance(o, pathlib.Path):
            return str(o)
        if isinstance(o, sqlite3.Row):
            return tuple(o)
        if isinstance(o, sqlite3.Cursor):
            return list(o)
        if isinstance(o, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
            return str(o)
        if isinstance(o, Exception):
            msg = getattr(o, "msg", str(o))
            args = getattr(o, "args", ())
            return f"<{type(o).__name__}(msg={msg!r}, args={args!r})>"
        if isinstance(o, typing.Mapping):
            # Keys the encoder accepts natively are kept; any other key is
            # replaced by its repr(). Values are converted recursively.
            result = {}
            for key, value in o.items():
                if not isinstance(key, (bool, int, float, str, type(None))):
                    key = repr(key)
                result[key] = self.default(value)
            return result
        if isinstance(o, (types.GeneratorType, typing.Iterator, typing.Generator)):
            return tuple(o)
        if hasattr(o, "__json__"):
            # ``__json__`` may be a method or a plain attribute holding the
            # value to serialize; either way the result is converted again.
            attr = getattr(o, "__json__")
            value = attr() if callable(attr) else attr
            return self.default(value)
        if hasattr(o, "__getitem__") and hasattr(o, "keys"):
            # Dict-like object: go through the Mapping branch so its keys and
            # values are sanitized exactly like those of a real mapping.
            return self.default(dict(o))
        if hasattr(o, "__dict__") and not isinstance(o, dict):
            return self.default(self._public_data_members(o))
        try:
            iterable = iter(o)
        except TypeError:
            pass
        else:
            return list(iterable)
        # Let the base class default method raise the TypeError.
        return super().default(o)

    @staticmethod
    def _public_data_members(o: typing.Any) -> dict:
        """Return ``o``'s non-dunder data attributes (instance and class level).

        Callables (methods, functions, classes, builtins), routines such as
        method descriptors, and generators are skipped. Generators are skipped
        because converting them would consume them.
        """
        return {
            name: member
            for name, member in inspect.getmembers(o)
            if not name.startswith("__")
            and not callable(member)
            and not inspect.isroutine(member)
            and not inspect.isgenerator(member)
        }
class AnyJSONEncoderCompact(AnyJSONEncoderBase):
    """JSON encoder for 'any' data that produces compact, single-line output.

    By default there is no indentation and no whitespace after the item or
    key separators; callers may still override ``indent`` or ``separators``.
    """

    SEPARATORS = (",", ":")

    def __init__(self, **kwargs):
        """Initialize with compact defaults; explicit keyword arguments win."""
        options = {"indent": None, "separators": self.SEPARATORS}
        options.update(kwargs)
        super().__init__(**options)
class AnyJSONEncoderPretty(AnyJSONEncoderBase):
    """JSON encoder for 'any' data that produces indented, readable output.

    By default nested values are indented by 4 spaces and the item separator
    has no trailing space, so lines do not end in whitespace; callers may
    still override ``indent`` or ``separators``.
    """

    SEPARATORS = (",", ": ")

    def __init__(self, **kwargs):
        """Initialize with pretty-print defaults; explicit keyword arguments win."""
        options = {"indent": 4, "separators": self.SEPARATORS}
        options.update(kwargs)
        super().__init__(**options)
class AnyJSONEncodersTests(unittest.TestCase):
    """Unit-tests for AnyJSONEncoderBase and its subclasses."""

    # Shared encoder instance used by the ``default()`` tests.
    encoder = AnyJSONEncoderBase()

    def test_default_handles_bool(self):
        self.assertIs(True, self.encoder.default(True))
        self.assertIs(False, self.encoder.default(False))

    def test_default_handles_int(self):
        value = 1
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_float(self):
        value = 1.00
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_str(self):
        value = "1.00"
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_list(self):
        value = [1, 2]
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_tuple(self):
        value = (1, 2)
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_none(self):
        value = None
        self.assertIs(value, self.encoder.default(value))

    def test_default_handles_set(self):
        value = {1, 2, 2, 1, 1, 3}
        expected = (1, 2, 3)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertTupleEqual(expected, actual)

    def test_default_handles_frozenset(self):
        value = frozenset([1, 2, 2, 1, 1, 3])
        expected = (1, 2, 3)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertTupleEqual(expected, actual)

    def test_default_handles_type(self):
        expected = str(object)
        actual = self.encoder.default(object)
        self.assertEqual(expected, actual)

    def test_default_handles_bytes(self):
        expected = "value"
        actual = self.encoder.default(expected.encode())
        self.assertEqual(expected, actual)

    def test_default_handles_complex(self):
        value = complex(1.2, 3.4)
        expected = (1.2, 3.4)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertTupleEqual(expected, actual)

    def test_default_handles_decimal(self):
        value = decimal.Decimal("0.1")
        expected = 0.1
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_enum(self):
        class TestEnum(enum.Enum):
            VALUE_1 = enum.auto()
            VALUE_2 = enum.auto()

        value = TestEnum.VALUE_2
        expected = "VALUE_2"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_datetime_naive(self):
        value = datetime.datetime(2001, 12, 31, 21, 44)
        expected = "2001-12-31T21:44:00.000000Z"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_datetime_aware(self):
        value = datetime.datetime(2001, 12, 31, 21, 44, tzinfo=datetime.timezone.utc)
        expected = "2001-12-31T21:44:00.000000Z"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_date(self):
        value = datetime.date(2001, 12, 31)
        expected = "2001-12-31"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_time(self):
        value = datetime.time(21, 44)
        expected = "21:44:00.000000"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_timedelta(self):
        value = datetime.timedelta(milliseconds=500)
        expected = 0.5
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_path(self):
        value = pathlib.Path(r"C:\Windows")
        expected = r"C:\Windows"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_sqlite_row(self):
        conn = sqlite3.connect(":memory:")
        self.addCleanup(conn.close)
        conn.row_factory = sqlite3.Row
        value = conn.execute("SELECT 1 AS a, 'x' AS b").fetchone()
        expected = (1, "x")
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertTupleEqual(expected, actual)

    def test_default_handles_sqlite_cursor(self):
        conn = sqlite3.connect(":memory:")
        self.addCleanup(conn.close)
        value = conn.execute("SELECT 1, 'x'")
        expected = [(1, "x")]
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, list)
        self.assertListEqual(expected, actual)

    def test_default_handles_ip_address_v4(self):
        value = ipaddress.IPv4Address("127.0.0.1")
        expected = "127.0.0.1"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_ip_address_v6(self):
        value = ipaddress.IPv6Address("2001:db8:3333:4444:5555:6666:7777:8888")
        expected = "2001:db8:3333:4444:5555:6666:7777:8888"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_exception(self):
        value = Exception("wow")
        expected = "<Exception(msg='wow', args=('wow',))>"
        actual = self.encoder.default(value)
        self.assertEqual(expected, actual)

    def test_default_handles_mapping_with_bad_keys(self):
        value = {object: 1}
        expected = {"<class 'object'>": 1}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)

    def test_default_handles_generators(self):
        def gen_func():
            for i in range(3):
                yield i

        value = gen_func()
        expected = (0, 1, 2)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertEqual(expected, actual)

    def test_default_handles_json_dunder_attr(self):
        class MyCls:
            def __init__(self, a):
                self.a = a
                self.__json__ = {"a": self.a}

        value = MyCls(42)
        expected = {"a": 42}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)
        expected_json = '{"a": 42}'
        actual_json = self.encoder.encode(value)
        self.assertEqual(expected_json, actual_json)

    def test_default_handles_json_dunder_func(self):
        class MyCls:
            def __init__(self, a):
                self.a = a

            def __json__(self):
                return {"a": self.a}

        value = MyCls(42)
        expected = {"a": 42}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)
        expected_json = '{"a": 42}'
        actual_json = self.encoder.encode(value)
        self.assertEqual(expected_json, actual_json)

    def test_default_handles_dict_like(self):
        class DictLike:
            def __init__(self, data: dict):
                self.data = data

            def keys(self):
                return self.data.keys()

            def __getitem__(self, item):
                return self.data[item]

        value = DictLike({"a": 1, "b": 2})
        expected = {"a": 1, "b": 2}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)

    def test_default_handles_custom_classes(self):
        class Person:
            def __init__(self, name: str, age: int):
                self.name = name
                self.age = age

        value = Person("Santa", 42)
        expected = {"age": 42, "name": "Santa"}
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, dict)
        self.assertEqual(expected, actual)

    def test_default_handles_iterators(self):
        items = [1, 2, 3]
        value = iter(items)
        expected = (1, 2, 3)
        actual = self.encoder.default(value)
        self.assertIsInstance(actual, tuple)
        self.assertEqual(expected, actual)

    def test_default_rejects_unsupported_objects(self):
        # A bare object() matches no branch and is not iterable, so the
        # call must end in the base JSONEncoder.default TypeError.
        with self.assertRaises(TypeError):
            self.encoder.default(object())

    def test_encode_base(self):
        expected = '[{"key": "value"}]'
        actual = AnyJSONEncoderBase().encode([dict(key="value")])
        self.assertEqual(expected, actual)

    def test_encode_compact(self):
        expected = '[{"key":"value"}]'
        actual = AnyJSONEncoderCompact().encode([dict(key="value")])
        self.assertEqual(expected, actual)

    def test_encode_pretty(self):
        expected = textwrap.dedent(
            """\
            [
                {
                    "key": "value"
                }
            ]"""
        )
        actual = AnyJSONEncoderPretty().encode([dict(key="value")])
        self.assertEqual(expected, actual)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment