Created April 21, 2020 04:19
Save myuanz/c52caa22b0076ab65b07048a911b546e to your computer and use it in GitHub Desktop.
Dataclasses 的一个轮子 / Homemade Dataclasses
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DataclassesBase:
    """A small hand-rolled alternative to the standard ``dataclasses`` module.

    Subclasses declare their fields as class-level annotations, optionally
    with a class-level default value::

        class Point(DataclassesBase):
            x: int
            label: str = "origin"

        Point(x=1).dict()  # {'x': 1, 'label': 'origin'}

    Fields declared on base classes are inherited. Instances are built from
    keyword arguments only. Every value passed to the constructor, and every
    later assignment to a public (non-underscore) attribute, is checked with
    ``isinstance`` against the declared type. Annotations must therefore be
    plain classes (or tuples of classes): subscripted generics such as
    ``List[int]`` and string annotations make ``isinstance`` itself raise
    ``TypeError``. Class-level default values are not type-checked.
    """

    @classmethod
    def _collect_annotations(cls):
        """Return ``{field_name: type}`` merged across *cls* and its bases.

        Classes are visited from the root of the MRO down. This lets a
        subclass redeclare a field (overriding its type) and keeps fields in
        declaration order, base-class fields first.
        """
        import inspect

        # Read each class's *own* annotations explicitly instead of relying
        # on ``self.__annotations__``. That lookup only ever sees the
        # most-derived annotated class, and it may not be reachable from an
        # instance when annotations are evaluated lazily (PEP 649).
        get_own = getattr(inspect, "get_annotations", None)  # Python 3.10+
        merged = {}
        for klass in reversed(cls.__mro__):
            if get_own is not None:
                own = get_own(klass)
            else:
                # Older Pythons: read the class's own namespace only, never
                # a base class's annotations reached through inheritance.
                own = klass.__dict__.get("__annotations__", {})
            merged.update(own)
        return merged

    def __init__(self, **data):
        """Populate the instance from keyword arguments.

        :param data: field values keyed by field name. Keys that are not
            declared fields trigger a warning and are ignored.
        :raises TypeError: if a value is not an instance of its declared type.
        :raises NameError: if any field without a class-level default is
            missing. The message lists every missing field.
        """
        import warnings

        self._annotations = self._collect_annotations()
        # Fields still waiting for a value; each one supplied is removed.
        pending = dict(self._annotations)
        for key, value in data.items():
            if key not in self._annotations:
                warnings.warn("Got an unexpected key %s" % key, stacklevel=2)
                continue
            expected = self._annotations[key]
            if not isinstance(value, expected):
                raise TypeError(
                    "The type of value of %s should be %s, but got %s" % (key, expected, type(value))
                )
            # Write straight into __dict__: the value was just validated, so
            # there is no need to run it through __setattr__ again.
            self.__dict__[key] = value
            pending.pop(key)
        # A field that was not supplied is only acceptable if the class
        # provides a default for it (visible to hasattr via the class).
        missing = [
            "\tname: %s, type: %s" % (name, field_type)
            for name, field_type in pending.items()
            if not hasattr(self, name)
        ]
        if missing:
            raise NameError("There are some missing values\n" + "\n".join(missing))

    def __setattr__(self, key: str, value: object):
        """Type-check and store an attribute.

        Names starting with an underscore are stored without any check.
        Undeclared public names are stored as well, but trigger a warning.

        :raises TypeError: if *value* does not match the declared field type.
        """
        import warnings

        if not key.startswith("_"):
            annotations = self.__dict__.get("_annotations")
            if annotations is None:
                # __init__ has not run yet (e.g. a subclass assigns before
                # calling super().__init__), so resolve the fields now.
                annotations = self._collect_annotations()
            if key not in annotations:
                warnings.warn("Got an unexpected key %s" % key, stacklevel=2)
            elif not isinstance(value, annotations[key]):
                raise TypeError(
                    "The type of value of %s should be %s, but got %s" % (key, annotations[key], type(value))
                )
        self.__dict__[key] = value

    def dict(self) -> dict:
        """Return a new ``{field_name: value}`` mapping of all declared fields.

        Fields not passed to the constructor report their class-level
        default. The copy is shallow: mutable values are shared with the
        instance.
        """
        return {name: getattr(self, name) for name in self._annotations}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import warnings | |
import string | |
import random | |
from dataclasses import DataclassesBase | |
def random_string(length=16):
    """Return a random alphanumeric string of *length* characters.

    Characters are drawn independently (with replacement) from ASCII letters
    and digits, so any non-negative length is supported and characters may
    repeat. This uses the non-cryptographic ``random`` module and is meant
    for generating test data only.

    :raises ValueError: if *length* is negative.
    """
    if length < 0:
        raise ValueError("length must be non-negative, got %r" % (length,))
    alphabet = string.ascii_letters + string.digits
    # random.choices samples with replacement; the previous random.sample
    # could not produce strings longer than the 62-character alphabet.
    return "".join(random.choices(alphabet, k=length))
class DataclassesBaseTestCase(unittest.TestCase):
    """Unit tests for ``DataclassesBase``.

    Uses the ``self.assert*`` helpers rather than bare ``assert`` statements,
    so the checks still run under ``python -O`` and failures carry useful
    messages.
    """

    class A(DataclassesBase):
        a: int
        b: str
        c: dict
        d: str = "2"

    def test_dataclasses_init(self):
        a = 1
        b = random_string()
        c = {random_string(): random_string()}
        a1 = self.A(a=a, b=b, c=c)
        self.assertEqual(a1.a, a)
        self.assertEqual(a1.b, b)
        self.assertEqual(a1.c, c)
        # d was not passed, so the class-level default must be visible.
        self.assertEqual(a1.d, self.A.d)

    def test_dataclasses_todict(self):
        a = 1
        b = random_string()
        c = {random_string(): random_string()}
        a1 = self.A(a=a, b=b, c=c)
        self.assertEqual(a1.dict(), {"a": a, "b": b, "c": c, "d": self.A.d})

    def test_unexpected_key(self):
        # An unknown keyword argument passed to the constructor.
        random_key = random_string()
        with self.assertWarns(UserWarning) as cm:
            self.A(a=2, b="asd", c={}, **{random_key: random_string()})
        self.assertEqual(str(cm.warning), "Got an unexpected key %s" % random_key)

        # An unknown attribute assigned after construction.
        a1 = self.A(a=2, b="asd", c={})
        random_key = random_string()
        with self.assertWarns(UserWarning) as cm:
            setattr(a1, random_key, random_string())
        self.assertEqual(str(cm.warning), "Got an unexpected key %s" % random_key)

    def test_type_error(self):
        with self.assertRaises(TypeError) as cm:
            self.A(a=2, b="asd", c={}, d=2)
        self.assertIn(
            "The type of value of d should be <class 'str'>, but got <class 'int'>",
            str(cm.exception),
        )

    def test_set_type_error(self):
        a1 = self.A(a=2, b="asd", c={})
        with self.assertRaises(TypeError) as cm:
            a1.d = 2
        self.assertIn(
            "The type of value of d should be <class 'str'>, but got <class 'int'>",
            str(cm.exception),
        )

    def test_missing_values(self):
        with self.assertRaises(NameError) as cm:
            self.A(a=2, c={})
        message = str(cm.exception)
        self.assertIn("There are some missing values", message)
        # b has no default and must be reported; d has a default and must not.
        self.assertIn("name: b", message)
        self.assertNotIn("name: d", message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.