Skip to content

Instantly share code, notes, and snippets.

@lrhache
Last active June 14, 2021 16:16
Show Gist options
  • Save lrhache/1b54d70b5481a64afddd52b3ee081153 to your computer and use it in GitHub Desktop.
Save lrhache/1b54d70b5481a64afddd52b3ee081153 to your computer and use it in GitHub Desktop.
Metaclass examples
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, Optional, Type, TypeVar, cast
# Generic type variable; used as the return annotation of MetaBaseModel.__call__.
T = TypeVar('T')
def wrap_setattr(definition: dict, fn: Callable) -> Callable:
    """Wrap an instance ``__setattr__`` with a model's field validation.

    The returned function accepts only attributes declared in
    ``definition['annotations']`` whose value is *exactly* of the annotated
    type (``type(value) is attr_type``). Subclasses, e.g. ``bool`` for an
    ``int`` field, are rejected. It then delegates to ``fn``.

    If ``fn`` is itself a wrapper produced by this function, the old check is
    dropped and its underlying target is wrapped instead. This happens when a
    model class inherits ``__setattr__`` from a parent model. MetaBaseModel's
    definitions already merge the parent's annotations, so stacking the
    parent's check would only reject legitimate type overrides and
    child-only fields.

    Args:
        definition: model definition; only its ``'annotations'`` mapping
            (field name -> type) is read, at assignment time.
        fn: the ``__setattr__`` to delegate to once validation passes.

    Returns:
        A ``(self, attr, value)`` callable suitable for use as ``__setattr__``.

    The returned function raises:
        AttributeError: if ``attr`` is not an annotated field.
        TypeError: if ``value`` is not exactly of the annotated type.
    """
    # Unwrap a previous model wrapper so validation never stacks across
    # model inheritance (see docstring).
    fn = getattr(fn, '_model_setattr_target', fn)

    @wraps(fn)
    def wrapped(self, attr, value):
        annotations = definition['annotations']
        if attr not in annotations:
            raise AttributeError(f'`{attr}` not found in model')
        if type(value) is not (attr_type := annotations[attr]):
            raise TypeError(f'`{attr}` should be of type `{attr_type}`')
        return fn(self, attr, value)

    # Remember the undecorated target so a subclass can re-wrap it directly.
    wrapped._model_setattr_target = fn
    return wrapped
def get_fields(cls) -> dict:
    """Return the model's fields as a ``{field name: type}`` mapping.

    MetaBaseModel installs this on every non-base model as a classmethod.
    The mapping includes fields inherited from parent models.

    A copy is returned so callers cannot alter the registered annotations,
    which also drive attribute validation.
    """
    return dict(MetaBaseModel.get_model_definition(cls.__name__)['annotations'])
def represent(self) -> str:
    """Render a model instance as ``{field:"value", ...}``.

    Fields are listed in the order returned by ``self.get_fields()``. A field
    with no value on the instance or its class is shown as ``None``.
    """
    body = ', '.join(
        f'{name}:"{getattr(self, name, None)}"' for name in self.get_fields()
    )
    return '{' + body + '}'
class Objects:
    """Minimal manager exposing ``get`` / ``create`` / ``all`` for one model.

    Storage is a class-level ``defaultdict(list)`` keyed by model name and
    shared by every manager; it stands in for a real database.
    """

    # Shared in-memory "database": model name -> list of stored instances.
    _mock_db = defaultdict(list)

    def __init__(self, model: dict):
        # `model` is a model definition dict (MetaBaseModel passes its
        # definition); only its 'cls' and 'name' entries are used here.
        self.model = model

    @property
    def model_name(self) -> str:
        """Name of the managed model (the definition's ``'name'`` entry)."""
        return self.model['name']

    def get(self, id: Optional[int] = None) -> 'BaseModel':
        """Return the single stored object whose ``id`` equals ``id``.

        Raises:
            ObjectNotFound: if no stored object has that id.
            TooManyResults: if more than one stored object has that id.
        """
        # this could be a call to the database
        matches = [o for o in self._mock_db[self.model_name] if o.id == id]
        if not matches:
            raise ObjectNotFound(f'{self.model_name} with id={id!r} not found')
        if len(matches) > 1:
            raise TooManyResults(
                f'{len(matches)} {self.model_name} objects share id={id!r}'
            )
        return matches[0]

    def create(self, **kwargs):
        """Instantiate the model with ``kwargs``, store it and return it."""
        obj = self.model['cls'](**kwargs)
        # this could be a call to the database
        self._mock_db[self.model_name].append(obj)
        return obj

    def all(self) -> list:
        """Return a snapshot list of every stored object of this model."""
        # Copy, so mutating the returned list cannot add or remove entries in
        # the mock database behind create()'s back.
        return list(self._mock_db[self.model_name])
class ObjectNotFound(Exception):
    """Raised by ``Objects.get`` when no stored object has the requested id."""
class TooManyResults(Exception):
    """Raised by ``Objects.get`` when several stored objects share the requested id."""
class MetaBaseModel(type):
    """Metaclass that turns annotated classes into small, type-checked models.

    Every class built with this metaclass is recorded in ``definitions`` under
    its class name. The first class ever built is flagged ``is_base`` and left
    as-is. Every later class additionally receives:

    * an instance ``__setattr__`` produced by ``wrap_setattr`` that validates
      field names and value types,
    * an ``objects`` manager (``Objects``) over the in-memory mock database,
    * a ``get_fields`` classmethod and a ``__repr__`` (``represent``).

    Class-level assignment to an annotated field is type-checked too; see
    ``__setattr__``.
    """

    # Registry of every model definition, keyed by class name. Names must be
    # unique: a later class with the same name replaces the earlier entry.
    definitions = {}

    def __new__(metacls: Type['MetaBaseModel'], clsname: str, bases: tuple, classdict: dict) -> 'MetaBaseModel':
        # Declare `objects` in the namespace so even the base model exposes the
        # attribute (as None); concrete models get a real manager below.
        classdict['objects'] = None
        cls = super().__new__(metacls, clsname, bases, classdict)
        definition = {
            'cls': cls,
            'name': clsname,
            # The very first class created with this metaclass is the base.
            'is_base': len(MetaBaseModel.definitions) == 0,
            'annotations': metacls._collect_annotations(bases, classdict),
        }
        # Register before installing hooks: the setattr() calls made there go
        # through MetaBaseModel.__setattr__, which looks the definition up.
        MetaBaseModel.definitions[clsname] = definition
        if not definition['is_base']:
            metacls._install_model_hooks(cls, definition)
        return cast(MetaBaseModel, cls)

    @staticmethod
    def _collect_annotations(bases: tuple, classdict: dict) -> dict:
        """Merge the fields inherited from model bases with the body's own."""
        annotations = {}
        for base in bases:
            # Non-model bases (plain mixins, `object`, ...) are not registered
            # and contribute no fields.
            if not isinstance(base, MetaBaseModel):
                continue
            base_definition = MetaBaseModel.get_model_definition(base.__name__)
            if base_definition is None:
                continue
            # setdefault: as with attribute lookup, the left-most base wins
            # when several bases declare the same field.
            for field, field_type in base_definition['annotations'].items():
                annotations.setdefault(field, field_type)
        # The class body's own annotations override anything inherited.
        # NOTE(review): on Python 3.14+ (PEP 649/749) the class namespace holds
        # an annotate function instead of `__annotations__`, so this lookup may
        # come back empty there. Confirm on the target interpreter.
        annotations.update(classdict.get('__annotations__', {}))
        return annotations

    @staticmethod
    def _install_model_hooks(model: 'MetaBaseModel', definition: dict) -> None:
        """Attach validation, the ``objects`` manager, ``get_fields`` and ``__repr__``."""
        # getattr() resolves the __setattr__ the class currently inherits:
        # object's, or a parent model's (already produced by wrap_setattr).
        inherited_setattr = getattr(model, '__setattr__')
        setattr(model, '__setattr__', wrap_setattr(definition, inherited_setattr))
        setattr(model, 'objects', Objects(definition))
        setattr(model, 'get_fields', classmethod(get_fields))
        setattr(model, '__repr__', represent)

    def __call__(cls, *args: Any, **kwargs: Any) -> T:
        """Create a model instance and assign its fields from ``kwargs``.

        ``args`` are accepted but unused; the class's ``__init__`` is invoked
        with no arguments. Only keyword arguments naming an annotated field
        are assigned, and others are silently skipped. Each assignment goes
        through the instance ``__setattr__``, so for non-base models type and
        name errors surface here.
        """
        annotations = MetaBaseModel.get_model_definition(cls.__name__)['annotations']
        instance = super().__call__()
        for field, value in kwargs.items():
            if field in annotations:
                setattr(instance, field, value)
        return instance

    def __setattr__(cls, attr: str, value: Any, /) -> None:
        """Type-check class-level assignment to annotated fields.

        Raises:
            ValueError: if ``attr`` is an annotated field and ``value`` is not
                exactly of the annotated type.
        """
        annotations = MetaBaseModel.definitions[cls.__name__]['annotations']
        attr_type = annotations.get(attr)
        # Attributes that are not fields (e.g. the hooks installed in
        # __new__) are assigned without any check.
        if attr_type and type(value) is not attr_type:
            raise ValueError(f"`{attr}` must be of type {attr_type}")
        return super().__setattr__(attr, value)

    @classmethod
    def get_model_definition(cls, model_name: str, /) -> Optional[dict]:
        """Return the registered definition for ``model_name``, or None if unknown.

        A definition is a dict with keys ``'cls'``, ``'name'``, ``'is_base'``
        and ``'annotations'`` (field name -> type, inherited fields included).
        """
        return MetaBaseModel.definitions.get(model_name)

    @classmethod
    def get_all_models(cls) -> dict:
        """Return ``{class name: class}`` for every registered non-base model."""
        return {
            definition['cls'].__name__: definition['cls']
            for definition in MetaBaseModel.definitions.values()
            if not definition['is_base']
        }
class BaseModel(metaclass=MetaBaseModel):
    """Root model class that the concrete models in this module inherit from.

    It is the first class created with MetaBaseModel here, so it is
    registered with ``is_base=True``. As a result it gets no field
    validation, no ``objects`` manager (the attribute stays None), no
    ``get_fields`` and no custom ``__repr__``. ``get_all_models()`` also
    leaves it out.
    """

    # Inherited by every model below; an instance that never sets `id` reads
    # this class-level default of 0.
    id: int = 0
class Person(BaseModel):
    """Model with the inherited ``id: int`` field plus a ``name: str`` field."""

    name: str
class Person2(Person):
    """Person variant whose ``id`` field is declared as ``str`` instead of ``int``."""

    # Override the inherited `id: int`: the class body's annotations are merged
    # after the inherited ones, so Person2's definition records `id` as `str`.
    id: str
class Address(BaseModel):
    """Model pairing a ``Person`` with a street name (plus the inherited ``id``)."""

    # Checked by exact type, so only instances of Person itself are accepted:
    # not strings, and not subclasses such as Person2.
    person: Person
    street: str
# --- Registry and field introspection ---------------------------------------
print("ALL MODELS:", MetaBaseModel.get_all_models())
print("\n-----\n")
print('Person model fields', Person.get_fields())
print('Person2 model fields', Person2.get_fields())
print('Address model fields', Address.get_fields())
print("\n-----\n")
# --- Class-level assignment, checked by MetaBaseModel.__setattr__ -----------
try:
    Person.id = "asd"  # Person.id is annotated int -> ValueError
except ValueError as e:
    print("ERROR:", e)
# Person2 re-annotates `id` as str, so the same assignment succeeds here.
Person2.id = "asd"
# --- Instance-level assignment, checked by the wrapped __setattr__ ----------
person = Person()
try:
    person.name = 0  # wrong type -> TypeError
except TypeError as e:
    print("ERROR:", e)
try:
    person.attribute_that_dont_exist = True  # not an annotated field -> AttributeError
except AttributeError as e:
    print("ERROR:", e)
print("\n-----\n")
person.name = "Louis"
address = Address()
try:
    address.person = "something"  # must be a Person instance -> TypeError
except TypeError as e:
    print("ERROR:", e)
print("\n-----\n")
address.person = person
address.street = "123 meta street"
print(address)  # formatted by `represent`, installed as __repr__
# --- Keyword construction through MetaBaseModel.__call__ --------------------
person_obj = Person(name='louis')
print("person object with init method:", person_obj.name)
print("\n-----\n")
# --- Manager API backed by the shared in-memory mock database ---------------
louis2 = Person.objects.create(name='Louis 2', id=100)
louis3 = Person.objects.create(name='Louis 3', id=101)
address1 = Address.objects.create(person=louis2, id=200, street="test")
print(address1)
print("\n-----\n")
try:
    # kwargs go through the same validation: a str is not a Person -> TypeError
    address2 = Address.objects.create(person="louis", id=201, street="test")
except TypeError as e:
    print("ERROR:", e)
print("\n-----\n")
persons = Person.objects.all()
print('All persons:', persons)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment