Last active
June 14, 2021 16:16
-
-
Save lrhache/1b54d70b5481a64afddd52b3ee081153 to your computer and use it in GitHub Desktop.
Metaclass examples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
from typing import Type, cast, Callable, TypeVar, Any | |
from functools import wraps | |
# Generic type variable; used as the return annotation of MetaBaseModel.__call__.
T = TypeVar('T')


def wrap_setattr(definition: dict, fn: Callable) -> Callable:
    """Build a ``__setattr__`` replacement that validates before delegating.

    Args:
        definition: model definition whose ``'annotations'`` dict maps each
            declared field name to its required type. It is read on every
            call, so later changes to it are honoured.
        fn: the ``__setattr__`` implementation invoked once checks pass.

    Returns:
        A function with ``fn``'s metadata (via ``functools.wraps``) that, when
        called as ``(self, attr, value)``, raises ``AttributeError`` if
        ``attr`` is not a declared field, raises ``TypeError`` if
        ``type(value)`` is not exactly the declared type, and otherwise
        returns ``fn(self, attr, value)``.
    """
    @wraps(fn)
    def wrapped(self, attr, value):
        fields = definition['annotations']
        if attr not in fields:
            raise AttributeError(f'`{attr}` not found in model')
        expected = fields[attr]
        # Exact type match on purpose: subclasses (e.g. bool for int) are rejected.
        if type(value) != expected:
            raise TypeError(f'`{attr}` should be of type `{expected}`')
        return fn(self, attr, value)

    return wrapped
def get_fields(cls) -> dict:
    """Return the model's fields as a new ``{name: annotated type}`` dict.

    MetaBaseModel installs this on every non-base model as a classmethod.
    The result is a copy: the underlying annotations dict is the one the
    type-checking ``__setattr__`` consults, so handing it out directly would
    let callers change a model's validation rules by mutating the result.
    """
    return dict(MetaBaseModel.get_model_definition(cls.__name__)['annotations'])
def represent(self) -> str:
    """Render a model instance as ``{field:"value", ...}``.

    Fields are listed in the order ``self.get_fields()`` yields them; a field
    with no value on the instance is shown as ``"None"``.
    """
    pairs = (f'{name}:"{getattr(self, name, None)}"' for name in self.get_fields())
    return '{%s}' % ', '.join(pairs)
class Objects:
    """Per-model manager offering simple queries over an in-memory mock DB.

    Every manager shares the class-level ``_mock_db``, a mapping from model
    name to the list of instances created through ``create``.
    """

    # model name -> list of stored instances; shared by all managers.
    _mock_db = defaultdict(list)

    def __init__(self, model: dict):
        """Bind the manager to a model definition.

        Args:
            model: the definition dict built by the metaclass; this class
                reads its ``'name'`` and ``'cls'`` keys.
        """
        self.model = model

    @property
    def model_name(self) -> str:
        """Name of the bound model, used as the key into ``_mock_db``."""
        return self.model['name']

    def get(self, id: int = None) -> 'BaseModel':
        """Return the single stored instance whose ``id`` equals *id*.

        Raises:
            ObjectNotFound: no stored instance has that id.
            TooManyResults: more than one stored instance has that id.
        """
        # this could be a call to the database
        matches = [obj for obj in self._mock_db[self.model_name] if obj.id == id]
        if not matches:
            raise ObjectNotFound()
        if len(matches) > 1:
            raise TooManyResults()
        return matches[0]

    def create(self, **kwargs) -> 'BaseModel':
        """Instantiate the model with *kwargs*, store the object and return it."""
        obj = self.model['cls'](**kwargs)
        # this could be a call to the database
        self._mock_db[self.model_name].append(obj)
        return obj

    def all(self) -> list:
        """Return a snapshot list of every stored instance of the model.

        A copy is returned so callers cannot add or remove stored rows by
        mutating the result.
        """
        return list(self._mock_db[self.model_name])
class ObjectNotFound(Exception):
    """Raised by ``Objects.get`` when no stored instance has the requested id."""
class TooManyResults(Exception):
    """Raised by ``Objects.get`` when several stored instances share the id."""
class MetaBaseModel(type):
    """Metaclass that registers model classes and equips them with model behaviour.

    Every class built with this metaclass is stored in ``definitions`` under
    its class name. Only the first class ever registered is flagged as the
    base (``is_base``). Each later class also gets a type-checking instance
    ``__setattr__``, an ``objects`` manager, a ``get_fields`` classmethod and
    a ``__repr__``.
    """

    # class name -> {'cls', 'name', 'is_base', 'annotations'}; shared registry.
    definitions = {}
    # Type-checking __setattr__ wrappers installed by __new__, mapped to the
    # function each one wraps, so a subclass can wrap the original function
    # instead of stacking its checks on top of its parent's wrapper.
    _wrapped_setattrs = {}

    def __new__(metacls: Type['MetaBaseModel'], clsname: str, bases: tuple, classdict: dict) -> 'MetaBaseModel':
        """Create, register and decorate a model class.

        The class's fields are those of its model bases (earlier bases win on
        conflicts, as with attribute lookup), overridden by its own
        annotations.
        """
        classdict['objects'] = None
        cls = super().__new__(metacls, clsname, bases, classdict)

        base_annotations = {}
        # `bases` is a tuple with all the parent classes. Walk it in reverse so
        # the first listed base has the final say on conflicting fields.
        for base in reversed(bases):
            base_definition = MetaBaseModel.get_model_definition(base.__name__)
            if base_definition is None:
                # Not a registered model (e.g. a plain mixin): nothing to inherit.
                continue
            base_annotations.update(base_definition['annotations'])
        # NOTE(review): on Python 3.14+ (PEP 649/749) annotations may be stored
        # lazily (`__annotate__`) instead of as `__annotations__` in classdict;
        # confirm and adapt if this must run there.
        base_annotations.update(classdict.get('__annotations__', {}))

        definition = {
            'cls': cls,
            'name': clsname,
            # only the very first registered class is treated as the base
            'is_base': len(MetaBaseModel.definitions) == 0,
            'annotations': base_annotations,
        }
        MetaBaseModel.definitions[clsname] = definition

        if not definition['is_base']:
            inherited_setattr = getattr(cls, '__setattr__')
            # If a parent model already installed a wrapper, chaining onto it
            # would re-run the parent's checks after ours and reject fields this
            # class adds or re-types (e.g. Person2.id: str vs Person.id: int).
            original_setattr = MetaBaseModel._wrapped_setattrs.get(
                inherited_setattr, inherited_setattr
            )
            checked_setattr = wrap_setattr(definition, original_setattr)
            MetaBaseModel._wrapped_setattrs[checked_setattr] = original_setattr
            setattr(cls, '__setattr__', checked_setattr)
            # attach the per-model manager instance
            setattr(cls, 'objects', Objects(definition))
            # create a class method for each of the models
            # you could also use the `staticmethod` decorator
            setattr(cls, 'get_fields', classmethod(get_fields))
            # add a repr method to each models
            setattr(cls, '__repr__', represent)
        return cast(MetaBaseModel, cls)

    def __call__(cls, *args: Any, **kwargs: Any) -> T:
        """Instantiate a model and assign the keyword arguments that are fields.

        ``__init__`` is called with no arguments. Positional arguments are
        ignored, and keyword arguments that are not model fields are silently
        dropped. Field values are assigned with ``setattr``, so they pass
        through the class's ``__setattr__`` (type-checked on non-base models).
        """
        definition = MetaBaseModel.get_model_definition(cls.__name__)
        instance = super().__call__()
        for k, value in kwargs.items():
            if k in definition['annotations']:
                # set all attribute for the model with the values passed to __init__
                setattr(instance, k, value)
        return instance

    def __setattr__(cls, attr: str, value: Any, /) -> Any:
        """Type-check assignments to annotated attributes on the model class.

        Raises:
            ValueError: *attr* is an annotated field and ``type(value)`` is not
                exactly the annotated type.
        """
        definition = MetaBaseModel.definitions.get(cls.__name__)
        # The class is not registered while type.__new__ is still running
        # (e.g. assignments from __init_subclass__); skip checking then.
        annotations = definition['annotations'] if definition is not None else {}
        attr_type = annotations.get(attr)
        if attr_type and type(value) != attr_type:
            raise ValueError(f"`{attr}` must be of type {attr_type}")
        return super().__setattr__(attr, value)

    @classmethod
    def get_model_definition(cls, model_name: str, /) -> dict:
        """Return the registered definition for *model_name*, or None if unknown."""
        return MetaBaseModel.definitions.get(model_name)

    @classmethod
    def get_all_models(cls):
        """Return ``{class name: class}`` for every registered non-base model."""
        return {
            c['cls'].__name__: c['cls'] for c in MetaBaseModel.definitions.values()
            if not c['is_base']
        }
class BaseModel(metaclass=MetaBaseModel):
    """Root of the model hierarchy.

    It is the first class created with MetaBaseModel in this module, so it is
    flagged as the base: it gets no manager or field validation itself.
    """

    # Primary-key field inherited by every model; class-level default 0.
    id: int = 0
class Person(BaseModel):
    """Model with a ``name`` string field in addition to the inherited ``id``."""

    name: str
class Person2(Person):
    """Person variant whose ``id`` field is declared as ``str``."""

    # overwrite type: replaces the inherited `id: int` annotation
    id: str
class Address(BaseModel):
    """Model pairing a street string with the Person who lives there."""

    # the model __setattr__ compares type(value) with Person exactly
    person: Person
    street: str
def main() -> None:
    """Run the metaclass demo.

    Prints the model registry and fields, then exercises class-level and
    instance-level type checks, the ``__repr__`` and the ``objects`` manager.
    Kept out of module top level so importing this file has no side effects.
    """
    print("ALL MODELS:", MetaBaseModel.get_all_models())
    print("\n-----\n")
    print('Person model fields', Person.get_fields())
    print('Person2 model fields', Person2.get_fields())
    print('Address model fields', Address.get_fields())
    print("\n-----\n")

    # class-level assignment is checked by MetaBaseModel.__setattr__ (ValueError)
    try:
        Person.id = "asd"
    except ValueError as e:
        print("ERROR:", e)
    Person2.id = "asd"

    # instance-level assignment is checked by the wrapped __setattr__
    person = Person()
    try:
        person.name = 0
    except TypeError as e:
        print("ERROR:", e)
    try:
        person.attribute_that_dont_exist = True
    except AttributeError as e:
        print("ERROR:", e)
    print("\n-----\n")

    person.name = "Louis"
    address = Address()
    try:
        address.person = "something"
    except TypeError as e:
        print("ERROR:", e)
    print("\n-----\n")

    address.person = person
    address.street = "123 meta street"
    print(address)
    person_obj = Person(name='louis')
    print("person object with init method:", person_obj.name)
    print("\n-----\n")

    # the `objects` manager stores created instances in the mock database
    louis2 = Person.objects.create(name='Louis 2', id=100)
    louis3 = Person.objects.create(name='Louis 3', id=101)
    address1 = Address.objects.create(person=louis2, id=200, street="test")
    print(address1)
    print("\n-----\n")
    try:
        address2 = Address.objects.create(person="louis", id=201, street="test")
    except TypeError as e:
        print("ERROR:", e)
    print("\n-----\n")
    persons = Person.objects.all()
    print('All persons:', persons)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment