Last active: May 22, 2023 07:33
-
-
Save hound672/89f0970c9033860b27b84b6f041156c6 to your computer and use it in GitHub Desktop.
Dataclass to Pydantic BaseModel
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from typing import Any, Type | |
from dataclasses import dataclass, is_dataclass | |
from pydantic import BaseModel, ValidationError | |
from pydantic.main import MetaModel | |
# Alias for "any class object"; used to annotate the dataclass argument.
AnyType = Type[Any]


def from_dataclass(DataClass: AnyType) -> Type[BaseModel]:
    """Create a pydantic model class with the same fields as a dataclass.

    Every field of ``DataClass`` becomes a field of the returned model,
    including fields inherited from dataclass bases. A field whose type is
    itself a dataclass is converted recursively, so nested input dicts are
    validated against the nested dataclass's fields too.

    Field defaults and ``default_factory`` callables are carried over.
    Fields without a default are required. The input dataclass is not
    modified.

    Args:
        DataClass: A dataclass type (not an instance).

    Returns:
        A new ``BaseModel`` subclass with the same ``__name__``,
        ``__qualname__`` and ``__module__`` as ``DataClass``.

    Raises:
        TypeError: If ``DataClass`` is not a dataclass type.
    """
    from dataclasses import MISSING, fields

    from pydantic import create_model

    def _default_for(dc_field):
        # Translate a dataclass field's default into create_model's form.
        if dc_field.default is not MISSING:
            return dc_field.default
        if dc_field.default_factory is not MISSING:
            # Imported lazily: Field(default_factory=...) needs pydantic >= 1.5,
            # and is only required when a dataclass actually uses a factory.
            from pydantic import Field
            return Field(default_factory=dc_field.default_factory)
        # Ellipsis tells create_model the field is required.
        return ...

    def _get_model(data_cls: AnyType) -> Type[BaseModel]:
        # Collect (type, default) pairs into a fresh dict. Writing back into
        # data_cls.__annotations__ would corrupt the caller's dataclass.
        definitions = {}
        for dc_field in fields(data_cls):
            field_type = dc_field.type
            if is_dataclass(field_type):
                field_type = _get_model(field_type)
            definitions[dc_field.name] = (field_type, _default_for(dc_field))
        model = create_model(data_cls.__name__, **definitions)
        # Mirror the naming metadata of the source class, so reprs and
        # error messages point at the dataclass's location.
        model.__module__ = data_cls.__module__
        model.__qualname__ = data_cls.__qualname__
        return model

    if not (isinstance(DataClass, type) and is_dataclass(DataClass)):
        raise TypeError(
            f'from_dataclass() expects a dataclass type, got {DataClass!r}'
        )
    return _get_model(DataClass)
class ModelTest(unittest.TestCase):
    """Checks that models built by ``from_dataclass`` accept and reject input."""

    def test_simple_success(self):
        @dataclass
        class SimpleDataClass:
            val: int

        payload = {'val': 123}
        model_cls = from_dataclass(SimpleDataClass)
        self.assertEqual(model_cls(**payload).dict(), payload)

    def test_simple_validation_error(self):
        @dataclass
        class SimpleDataClass:
            val: int

        model_cls = from_dataclass(SimpleDataClass)
        # 'word' cannot be coerced to int, so building the model must fail.
        self.assertRaises(ValidationError, model_cls, val='word')

    def test_nested_success(self):
        @dataclass
        class DataClassInner:
            inner_val: str

        @dataclass
        class NestedDataClass:
            val: str
            inner_data_class: DataClassInner

        payload = {'val': 'word', 'inner_data_class': {'inner_val': 'word'}}
        model_cls = from_dataclass(NestedDataClass)
        self.assertEqual(model_cls(**payload).dict(), payload)

    def test_nested_validation_error(self):
        @dataclass
        class DataClassInner:
            inner_val: int

        @dataclass
        class NestedDataClass:
            val: str
            inner_data_class: DataClassInner

        model_cls = from_dataclass(NestedDataClass)
        # The nested dict carries a non-integer inner_val, so validation fails.
        self.assertRaises(
            ValidationError,
            model_cls,
            val='word',
            inner_data_class={'inner_val': 'word'},
        )


if __name__ == '__main__':
    # verbosity=2 makes the runner list each test as it executes.
    unittest.main(verbosity=2)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.