How to implement a subclass-based plugin architecture? How to handle polymorphism when loading objects from persisted data?
Last active
January 11, 2023 15:06
-
-
Save ivangeorgiev/454179c6f1c6ff5f22d4e3c76a302415 to your computer and use it in GitHub Desktop.
Python Subclass Factory
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
from typing import Callable, Hashable | |
# A name getter maps a class to the hashable key it is looked up by.
NameGetter = Callable[[type], Hashable]


# %%
def get_name_from_attribute(attribute_name: str) -> NameGetter:
    """Return a name getter that reads ``attribute_name`` from a class.

    The returned callable propagates ``AttributeError`` when the class has
    no such attribute.
    """

    def read_attribute(klass: type) -> Hashable:
        return getattr(klass, attribute_name)

    return read_attribute
# %% | |
def get_subclasses(cls):
    """Yield every transitive subclass of ``cls`` (``cls`` itself excluded).

    Order is depth-first post-order: a subclass's own descendants are yielded
    before the subclass. A class reachable through several bases (multiple
    inheritance) is yielded once per path.
    """
    # Explicit DFS stack of (class, iterator over its direct subclasses).
    pending = [(cls, iter(cls.__subclasses__()))]
    while pending:
        current, children = pending[-1]
        child = next(children, None)
        if child is not None:
            pending.append((child, iter(child.__subclasses__())))
            continue
        pending.pop()
        if pending:  # never yield the root class itself
            yield current
def subclass_factory(base_class: type, *, name_attribute: str = None, get_name: NameGetter = None):
    """Return a function that instantiates a subclass of ``base_class`` by name.

    Keyword Args:
        name_attribute: Class attribute holding each subclass's name. It is used
            only when ``get_name`` is not given; it falls back to
            ``"__qualname__"`` when empty.
        get_name: Callable mapping a subclass to its name. It takes precedence
            over ``name_attribute``.

    The returned ``factory(subclass_name, *args, **kwargs)`` walks all
    transitive subclasses of ``base_class`` on every call. It instantiates
    the first one whose name equals ``subclass_name`` with
    ``*args, **kwargs``, and raises ``ValueError`` when none matches.
    """
    resolve_name = get_name or get_name_from_attribute(name_attribute or '__qualname__')

    def factory(subclass_name, *args, **kwargs):
        candidates = (
            klass
            for klass in get_subclasses(base_class)
            if subclass_name == resolve_name(klass)
        )
        found = next(candidates, None)
        if found is None:
            raise ValueError(f"Subclass with name '{subclass_name}' is not registered")
        return found(*args, **kwargs)

    return factory
# %% | |
class BaseClass:
    """Root class of the demo hierarchy."""


class Subclass(BaseClass):
    """Demo subclass exposing ``name = 'i-am'`` for attribute-based lookup.

    It records the positional and keyword arguments it was constructed with.
    """

    name = 'i-am'

    def __init__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs
# %% | |
# Constructor arguments shared by the demo cells below.
args = (1, 2)
kwargs = {"arg": "value"}

# %%
# Default lookup: subclasses are keyed by __qualname__.
factory_by_qualname = subclass_factory(BaseClass)
instance_by_qualname = factory_by_qualname('Subclass', *args, **kwargs)
assert isinstance(instance_by_qualname, Subclass)
assert instance_by_qualname.args == args, instance_by_qualname.args
# Compare against the expected kwargs; comparing the attribute with itself
# could never fail.
assert instance_by_qualname.kwargs == kwargs, instance_by_qualname.kwargs
# %% | |
# Attribute lookup: subclasses are keyed by their `name` class attribute.
factory_by_name = subclass_factory(BaseClass, name_attribute='name')
instance_by_name = factory_by_name('i-am', *args, **kwargs)
assert isinstance(instance_by_name, Subclass)
assert instance_by_name.args == args, instance_by_name.args
# Compare against the expected kwargs (self-comparison is always true).
assert instance_by_name.kwargs == kwargs, instance_by_name.kwargs
# %% | |
# Custom getter: subclasses are keyed as "<module>.<qualname>". The expected
# key is built from this module's __name__, so the cell also works when the
# file is imported instead of being run as __main__.
factory_by_getter = subclass_factory(
    BaseClass, get_name=lambda cls: f"{cls.__module__}.{cls.__qualname__}"
)
instance_by_getter = factory_by_getter(f"{__name__}.Subclass", *args, **kwargs)
# Check the instance created in this cell, not the one from the previous cell.
assert isinstance(instance_by_getter, Subclass)
assert instance_by_getter.args == args, instance_by_getter.args
assert instance_by_getter.kwargs == kwargs, instance_by_getter.kwargs
# %% |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Tools for implementing a subclass-based plugin architecture.
See https://gist.github.com/ivangeorgiev/454179c6f1c6ff5f22d4e3c76a302415 | |
""" | |
from __future__ import annotations | |
from operator import attrgetter | |
from typing import Callable, Hashable | |
# Attribute used as a subclass's key when no name getter is supplied.
CLASS_NAME_ATTRIBUTE = "__qualname__"
# Maps a class to the hashable key it is registered under.
NameGetter = Callable[[type], Hashable]


def enumerate_subclasses(cls):
    """Yield every transitive subclass of ``cls``, excluding ``cls`` itself.

    Order is depth-first post-order: a subclass's own descendants come
    before the subclass. A class reachable through several bases (multiple
    inheritance) is yielded once per path.
    """
    frames = [(cls, iter(cls.__subclasses__()))]
    while frames:
        owner, remaining = frames[-1]
        nxt = next(remaining, None)
        if nxt is None:
            frames.pop()
            if frames:  # the root class is never yielded
                yield owner
        else:
            frames.append((nxt, iter(nxt.__subclasses__())))
def get_subclasses(
    base_class: type,
    get_name: NameGetter | None = None,
    subclass_cache: dict[Hashable, type] | None = None,
) -> dict[Hashable, type]:
    """Return a mapping of name -> subclass for all subclasses of ``base_class``.

    Args:
        base_class: Class whose transitive subclasses are collected.
        get_name: Maps a subclass to its key. When omitted, the subclass's
            ``CLASS_NAME_ATTRIBUTE`` (``__qualname__``) is used.
        subclass_cache: Optional dict to fill and return. If it is already
            non-empty, it is returned untouched without re-scanning, which
            lets callers reuse a previous scan.

    Returns:
        ``subclass_cache`` (or a new dict) mapping names to subclasses. When
        two subclasses share a name, the one enumerated last wins.
    """
    if subclass_cache is None:
        subclass_cache = {}
    if not subclass_cache:
        get_name = get_name or attrgetter(CLASS_NAME_ATTRIBUTE)
        for subclass in enumerate_subclasses(base_class):
            subclass_cache[get_name(subclass)] = subclass
    return subclass_cache
def subclass_factory(
    base_class: type,
    *,
    get_name: NameGetter | None = None,
    cache_subclasses: bool = False,
):
    """Create a callable that instantiates subclasses of ``base_class`` by name.

    Args:
        base_class: Root of the plugin hierarchy.
        get_name: Maps a subclass to its lookup key. When omitted,
            ``get_subclasses`` falls back to ``__qualname__``.
        cache_subclasses: If true, the subclass scan is kept after the first
            call that finds any subclass, so subclasses defined later are not
            visible. If false, the hierarchy is re-scanned on every call.

    Returns:
        ``factory(subclass_name, *args, **kwargs)``. It returns
        ``subclass(*args, **kwargs)`` for the subclass registered under
        ``subclass_name`` and raises ``ValueError`` when no subclass has
        that name.
    """
    subclass_cache = {}

    def factory(subclass_name, *args, **kwargs):
        if not cache_subclasses:
            # Empty the cache so get_subclasses() rebuilds it and sees
            # subclasses defined since the previous call.
            subclass_cache.clear()
        subclasses = get_subclasses(base_class, get_name, subclass_cache)
        try:
            # Direct dict lookup instead of scanning every item.
            subclass = subclasses[subclass_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which can never be registered
            # keys. Keep ValueError as the single "not found" error.
            raise ValueError(
                f"'{subclass_name}' is not registered as subclass of {base_class}"
            ) from None
        return subclass(*args, **kwargs)

    return factory
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
from typing import Callable, Hashable | |
NameGetter = Callable[[type], Hashable]  # class -> registry key


def get_name_from_attribute(attribute_name: str) -> NameGetter:
    """Build a name getter returning ``getattr(klass, attribute_name)``."""

    def _lookup(klass: type) -> Hashable:
        return getattr(klass, attribute_name)

    return _lookup


def get_name_from_class_name():
    """Build a name getter that keys a class by its ``__qualname__``."""
    return get_name_from_attribute(attribute_name="__qualname__")
def get_subclasses(klass):
    """Yield all transitive subclasses of ``klass`` in depth-first post-order.

    Descendants of a subclass are produced before the subclass itself;
    ``klass`` is not included.
    """
    work = [(klass, iter(klass.__subclasses__()))]
    while work:
        node, kids = work[-1]
        kid = next(kids, None)
        if kid is None:
            work.pop()
            if work:  # skip the starting class
                yield node
        else:
            work.append((kid, iter(kid.__subclasses__())))
class SubclassRegistry:
    """Name -> subclass lookup over all transitive subclasses of a base class.

    ``registry[name]`` returns the subclass keyed ``name`` by ``get_name`` and
    raises ``KeyError`` when it is unknown. ``name in registry`` and
    ``iter(registry)`` work on the same mapping.
    """

    # Last computed name -> subclass mapping; None until first accessed.
    _subclass_cache = None

    def __init__(
        self, base_class: type, get_name: NameGetter, cache_subclasses: bool = False
    ):
        self.base_class = base_class
        self.get_name = get_name
        self.cache_subclasses = cache_subclasses

    @property
    def subclasses(self):
        """Return the name -> subclass dict.

        The dict is rebuilt on every access unless ``cache_subclasses`` is true
        and a non-empty mapping is already cached. When two subclasses share a
        name, the one enumerated last wins.
        """
        if not (self._subclass_cache and self.cache_subclasses):
            self._subclass_cache = {
                self.get_name(klass): klass for klass in get_subclasses(self.base_class)
            }
        return self._subclass_cache

    def __getitem__(self, name):
        return self._get_klass(name)

    # Without __contains__/__iter__, `in` and iter() fall back to the legacy
    # sequence protocol (registry[0], registry[1], ...) and raise KeyError.
    # __len__ is deliberately not defined so instances stay truthy as before.
    def __contains__(self, name):
        return name in self.subclasses

    def __iter__(self):
        return iter(self.subclasses)

    def _get_klass(self, name):
        return self.subclasses[name]
class SubclassFactory:
    """Callable that instantiates the class registered under a given name."""

    def __init__(self, registry: SubclassRegistry):
        self.registry = registry

    def __call__(self, name, *args, **kwargs):
        """Return ``registry[name](*args, **kwargs)``; unknown names raise the registry's error."""
        return self.registry[name](*args, **kwargs)
# %% | |
class BaseClass:
    """Root class for the registry/factory demo."""


class SubClass(BaseClass):
    """Demo subclass that records the arguments it was built with."""

    def __init__(self, *args, **kwargs) -> None:
        self.args, self.kwargs = args, kwargs
# Registry keyed by __qualname__, scanned once and then cached.
r = SubclassRegistry(
    base_class=BaseClass,
    get_name=get_name_from_class_name(),
    cache_subclasses=True,
)
f = SubclassFactory(registry=r)
# %% | |
# Lookup by name returns the class itself (a bare `r["SubClass"]` statement
# has no effect outside a notebook, so assert it instead).
assert r["SubClass"] is SubClass, r["SubClass"]
# %%
args = (1, 2)
kwargs = {"arg": "value"}
instance = f("SubClass", *args, **kwargs)
assert isinstance(instance, SubClass), type(instance)
assert instance.args == args, instance.args
assert instance.kwargs == kwargs, instance.kwargs
# %% |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
import inspect | |
from types import FunctionType | |
from typing import Any, Callable, Hashable, overload, TypeVar | |
T = TypeVar('T', bound=type) | |
NameGetter = Callable[[type], Hashable] | |
ClassDecorator = Callable[[T], T] | |
def get_name_from_attribute(attribute_name: str) -> NameGetter: | |
def getter(klass: type) -> Hashable: | |
return getattr(klass, attribute_name) | |
return getter | |
def get_name_from_qualname() -> NameGetter: | |
return get_name_from_attribute('__qualname__') | |
@overload | |
def registry(arg: T) -> T: | |
"""Use as @decorator""" | |
@overload | |
def registry() -> ClassDecorator: | |
"""Use as @decorator()""" | |
@overload | |
def registry(arg: str) -> ClassDecorator: | |
"""Use as @decorator(name_attribute)""" | |
@overload | |
def registry(arg: NameGetter | None = None) -> ClassDecorator: | |
"""Use as @decorator(name_getter)""" | |
# %% | |
def registry(arg: T | str | NameGetter | None = None) -> T | ClassDecorator: | |
def decorator(klass): | |
class DecoratedKlass(klass): | |
__subclasses = {} | |
def __init_subclass__(cls) -> None: | |
cls.__subclasses[get_name(cls)] = cls | |
print(cls.__subclasses) | |
return DecoratedKlass | |
# used as @decorator | |
if inspect.isclass(arg): | |
get_name = get_name_from_qualname() | |
print('used as @decorator with getter ', get_name) | |
return decorator(arg) | |
# used as @decorator() | |
if arg is None: | |
get_name = get_name_from_qualname() | |
print('used as @decorator() with getter ', get_name) | |
return decorator | |
# used as @decorator(name_attribute) | |
if isinstance(arg, str): | |
get_name = get_name_from_attribute(arg) | |
print('used as @decorator(name_attribute) with getter ', get_name) | |
return decorator | |
# used as @decorator(name_getter) | |
if isinstance(arg, FunctionType): | |
get_name = arg | |
print('used as @decorator(name_getter) with getter ', get_name) | |
return decorator | |
raise TypeError(f"Expected type, str, FunctionType or None, but {type(arg)} given") | |
# %% | |
# Bare decorator: subclasses are keyed by __qualname__.
@registry
class BaseClassWithoutArgList:
    pass


class SubClassWithoutArgList(BaseClassWithoutArgList):
    pass


# Verify the registration. The registry lives in the name-mangled private
# attribute that `registry` creates on the decorated class.
assert BaseClassWithoutArgList._DecoratedKlass__subclasses == {
    "SubClassWithoutArgList": SubClassWithoutArgList
}
# %% | |
# Decorator called with no arguments: subclasses are keyed by __qualname__.
@registry()
class BaseClassEmptyArgList:
    pass


class SubClassEmptyArgList(BaseClassEmptyArgList):
    pass


# Verify the registration (private, name-mangled registry attribute).
assert BaseClassEmptyArgList._DecoratedKlass__subclasses == {
    "SubClassEmptyArgList": SubClassEmptyArgList
}
# %% | |
# Attribute form: subclasses are keyed by their `name` class attribute.
@registry("name")
class BaseClassWithNameAttr:
    pass


class SubClassWithNameAttr(BaseClassWithNameAttr):
    name = "i-am-child"


# Verify the registration (private, name-mangled registry attribute).
assert BaseClassWithNameAttr._DecoratedKlass__subclasses == {
    "i-am-child": SubClassWithNameAttr
}
# %% | |
def get_name(klass: type):
    """Return ``"<module>.<qualname>"`` for ``klass``."""
    module_name = klass.__module__ + "."
    return module_name + klass.__qualname__


# Getter form: subclasses are keyed by get_name(subclass).
@registry(get_name)
class BaseClassWithNameGetter:
    pass


class SubClassWithNameGetter(BaseClassWithNameGetter):
    pass


# Verify the registration. A class defined at module level has
# __module__ == this module's __name__.
assert BaseClassWithNameGetter._DecoratedKlass__subclasses == {
    f"{__name__}.SubClassWithNameGetter": SubClassWithNameGetter
}
# %% |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from operator import attrgetter | |
import pytest | |
from subclass_factory_v2 import enumerate_subclasses, get_subclasses, subclass_factory | |
class BaseClass:
    """Fixture root that remembers its constructor arguments."""

    def __init__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs


class SubClass(BaseClass):
    """First-level fixture subclass."""

    the_name = "sub-class"


class SubSubClass(SubClass):
    """Second-level fixture subclass."""

    the_name = "sub-sub-class"
class TestEnumerateSubclasses:
    def test_should_yield_subclasses(self):
        # Descendants are yielded depth-first, each before its parent.
        expected = [SubSubClass, SubClass]
        assert list(enumerate_subclasses(BaseClass)) == expected
class TestGetSubclasses:
    def test_should_return_dictionary_of_subclasses_using_class_name_no_get_name_given(self):
        # Given
        # When
        actual = get_subclasses(BaseClass)
        # Then
        assert actual == {"SubSubClass": SubSubClass, "SubClass": SubClass}

    def test_should_return_dictionary_of_subclasses_using_passed_get_name(self):
        # Given
        # When
        actual = get_subclasses(BaseClass, get_name=attrgetter("the_name"))
        # Then
        assert actual == {"sub-sub-class": SubSubClass, "sub-class": SubClass}

    def test_should_return_passed_cache(self):
        # Given
        subclass_cache = {}
        # When
        actual = get_subclasses(BaseClass, subclass_cache=subclass_cache)
        # Then
        assert actual is subclass_cache

    def test_should_refresh_passed_cache_if_empty(self):
        # Given
        subclass_cache = {}
        # When: pass the cache by keyword. The second positional parameter is
        # get_name, so passing it positionally would never touch the cache.
        actual = get_subclasses(BaseClass, subclass_cache=subclass_cache)
        # Then: the passed dict itself was filled.
        assert actual is subclass_cache
        assert subclass_cache == {"SubSubClass": SubSubClass, "SubClass": SubClass}

    def test_should_use_passed_cache_if_not_empty(self):
        # Given
        subclass_cache = {"a": "b"}
        # When
        actual = get_subclasses(BaseClass, subclass_cache=subclass_cache)
        # Then
        assert actual == {"a": "b"}
class TestSubclssFactory:
    def test_should_create_instances_using_class_name_no_get_name_given(self):
        # Given
        factory = subclass_factory(BaseClass)
        # When
        instance = factory("SubClass", 1, 2, x=3)
        # Then
        assert isinstance(instance, SubClass)
        assert instance.args == (1, 2)
        assert instance.kwargs == {"x": 3}

    def test_should_create_instances_using_name_getter(self):
        # Given
        factory = subclass_factory(BaseClass, get_name=attrgetter("the_name"))
        # When
        instance = factory("sub-class", 1, 2, x=3)
        # Then
        assert isinstance(instance, SubClass)
        assert instance.args == (1, 2)
        assert instance.kwargs == {"x": 3}

    def test_should_raise_ValueError_if_subclss_is_not_registered(self):
        # Given
        factory = subclass_factory(BaseClass)
        # When
        # Then
        with pytest.raises(ValueError):
            factory("You-Do-Not-Know-Me")

    def test_should_cache_subclasses_if_requested(self):
        # Given: a private hierarchy. A subclass of the shared BaseClass
        # defined here could stay alive until the cyclic GC runs and show up
        # in the other tests that enumerate BaseClass, depending on test order.
        class LocalBase:
            pass

        class EarlySubclass(LocalBase):
            pass

        # Key by __name__. The default __qualname__ of a class defined in this
        # method contains "<locals>", so a short name would never match and
        # the ValueError below would be raised even without caching.
        get_name = attrgetter("__name__")
        cached_factory = subclass_factory(
            LocalBase, get_name=get_name, cache_subclasses=True
        )
        uncached_factory = subclass_factory(LocalBase, get_name=get_name)
        assert isinstance(cached_factory("EarlySubclass"), EarlySubclass)

        # When
        class LateSubclass(LocalBase):
            pass

        # Then: the subclass is findable, but the caching factory misses it.
        assert isinstance(uncached_factory("LateSubclass"), LateSubclass)
        with pytest.raises(ValueError):
            cached_factory("LateSubclass")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment