# Created October 19, 2023 20:39
# Save cbonesana/e88106404b85febb0122d33a535a83cf to your computer and use it in GitHub Desktop.
# Pydantic: how to manage subclasses
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below.
# To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import json
from typing import Any

from pydantic import BaseModel, root_validator
# Lookup table from class name to model class. Creature.__init_subclass__ adds
# every subclass as it is defined, and the base Creature is added explicitly
# right after its definition; Forest uses it to turn "_name" tags back into
# classes.
class_registry: dict[str, Any] = {}
# base class | |
class Creature(BaseModel): | |
# this field will be used to keep track of the real type | |
_name: str = "" | |
# register subclasses when created | |
def __init_subclass__(cls, **kwargs: Any) -> None: | |
super().__init_subclass__(**kwargs) | |
class_registry[cls.__name__] = cls | |
# set the _name field with the correct class-name | |
@root_validator(pre=False) | |
def set_creature_type(cls, values): | |
values["_name"] = cls.__name__ | |
return values | |
# __init_subclass__ only fires for classes derived from Creature, never for
# Creature itself, so the base class is registered by hand. This allows a
# "_name" of "Creature" to be resolved as well.
class_registry.update({Creature.__name__: Creature})
# Concrete creature types. No body is needed: subclassing Creature is what
# registers them in class_registry.
class Gargoyle(Creature):
    pass
# Registered automatically through Creature.__init_subclass__.
class Gnome(Creature):
    pass
# Troll derives from Gnome rather than from Creature directly. Gnome inherits
# Creature.__init_subclass__, so deeper levels of the hierarchy get registered
# the same way.
class Troll(Gnome):
    pass
# A Forest is itself a Creature that holds a mixed collection of creatures.
class Forest(Creature):
    """Creature holding a heterogeneous list of creatures.

    The field is declared as ``list[Creature]``, so on its own pydantic would
    validate every dict entry as a plain ``Creature``. ``set_creatures`` runs
    before field validation and rebuilds entries carrying a ``"_name"`` tag as
    the registered subclass of that name.
    """

    creatures: list[Creature]

    @root_validator(pre=True)
    def set_creatures(cls, values):
        """Instantiate tagged dict entries of ``creatures`` as their subclass.

        A dict entry with a ``"_name"`` key is built as
        ``class_registry[entry["_name"]](**entry)``. Model instances and dicts
        without ``"_name"`` are passed through unchanged and left to the normal
        ``list[Creature]`` field validation.

        Args:
            values: raw input mapping given to the model, before field
                validation.

        Returns:
            A new mapping with ``"creatures"`` replaced by a fresh list; the
            caller's list is never modified. If ``"creatures"`` is missing or
            is not a list/tuple, ``values`` is returned unchanged.

        Raises:
            ValueError: if an entry's ``"_name"`` is not in ``class_registry``.
        """
        creatures = values.get("creatures")
        # Missing or non-sequence input: let the field validation report the
        # proper error instead of raising a bare KeyError/TypeError here.
        if not isinstance(creatures, (list, tuple)):
            return values

        rebuilt = []
        for creature in creatures:
            if isinstance(creature, dict) and "_name" in creature:
                class_name = creature["_name"]
                subclass = class_registry.get(class_name)
                if subclass is None:
                    # ValueError rather than KeyError so pydantic reports it as
                    # a validation error on the model.
                    raise ValueError(f"unknown creature type: {class_name!r}")
                creature = subclass(**creature)
            rebuilt.append(creature)

        # Build a new mapping and list instead of assigning into the caller's
        # list, which would otherwise be mutated in place.
        return {**values, "creatures": rebuilt}
if __name__ == "__main__":
    # One creature of each kind; Troll sits two levels below Creature.
    creatures = [Creature(), Gargoyle(), Troll(), Gnome()]
    f = Forest(creatures=creatures)
    print(f)

    # Round trip: dump the forest to a JSON string, parse it back into plain
    # Python objects and build a brand-new Forest from them. This relies on
    # each creature's "_name" entry being present in the serialized output so
    # Forest.set_creatures can pick the matching subclass again.
    f2 = Forest(**json.loads(f.json()))
    print(f2)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.