Created November 8, 2022 22:10
-
-
Save ahancock1/3bd423248cd5bba2e25bd417ddac0f43 to your computer and use it in GitHub Desktop.
Inversion-of-control (dependency injection) service container.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import inspect
from typing import Callable, Generic, Protocol, TypeVar
from typing import overload
from typing import get_type_hints, get_args, get_origin
# Type variable tying a requested type to the instance handed back for it,
# e.g. ``get_service(resolve_type: type[T]) -> T``.
T = TypeVar("T")
class IServiceProvider(Protocol):
    """Structural interface for objects that hand out registered services."""

    def get_service(self, resolve_type: type[T], service_name: str = None) -> T | None:
        """Look up the service registered for ``resolve_type``.

        Args:
            resolve_type: The type being requested.
            service_name: Optional name distinguishing registrations of one type.

        Returns:
            The resolved instance, or ``None`` when nothing is available.
        """
class IServiceFactory(Protocol[T]):
    """Structural interface for objects that produce an instance of ``T``."""

    def get_instance(self, provider: IServiceProvider) -> T:
        """Return an instance of ``T``.

        ``provider`` is passed through so the instance's own dependencies
        can be looked up while it is being created.
        """
class _SelfFactory:
    """Factory that resolves to whichever provider asks for it."""

    def get_instance(self, provider: IServiceProvider) -> IServiceProvider:
        return provider


class ServiceProvider:
    """Resolve service instances from a registry of factories.

    The registry maps ``(resolve_type, service_name)`` to a list of factories.
    A plain request returns the first factory's instance. A collection request
    such as ``list[X]``, ``tuple[X, ...]`` or ``set[X]`` gathers every
    unnamed registration of ``X``; factories that yield ``None`` are skipped.
    The provider always resolves ``IServiceProvider`` to itself.
    """

    # Collection origins that trigger "resolve all registrations" behaviour.
    _COLLECTION_TYPES = (list, tuple, set)

    _services: dict[tuple[type, str | None], list[IServiceFactory]]

    def __init__(self, services: dict[tuple[type, str | None], list[IServiceFactory]]) -> None:
        # Snapshot the registry, including each factory list, so building a
        # provider never writes into the caller's (the container's) dict.
        self._services = {key: list(factories) for key, factories in services.items()}
        # Register the provider itself as a factory list, like every other
        # entry, so _resolve_service can index it. Previously `self` was stored
        # bare, which made resolving IServiceProvider raise TypeError.
        self._services[(IServiceProvider, None)] = [_SelfFactory()]

    def _resolve_services(self,
                          resolve_type: type[T],
                          type_origin: type[list | tuple | set]) -> T:
        """Collect all unnamed registrations for each type argument of ``resolve_type``.

        Args:
            resolve_type: A parameterized collection such as ``list[X]``.
            type_origin: Its origin (``list``, ``tuple`` or ``set``), used to
                build the returned collection.

        Returns:
            A ``type_origin`` instance holding every non-``None`` service found.
        """
        items: list = []
        for type_arg in get_args(resolve_type):
            if type_arg is Ellipsis:
                # tuple[X, ...]: the ellipsis is not a service type.
                continue
            if (get_origin(type_arg) or type_arg) in self._COLLECTION_TYPES:
                # Nested collection such as list[list[X]]: flatten it in.
                nested = self.get_service(type_arg)
                if nested:
                    items.extend(nested)
                continue
            for factory in self._services.get((type_arg, None), []):
                service = factory.get_instance(self)
                if service is not None:
                    items.append(service)
        # Build the requested collection once; tuple and set have no append().
        return type_origin(items)

    def _resolve_service(self, resolve_type: type[T], service_name: str | None = None) -> T | None:
        """Return the instance from the first factory under the key, or ``None``."""
        factories = self._services.get((resolve_type, service_name))
        if not factories:
            return None
        return factories[0].get_instance(self)

    def get_service(self, resolve_type: type[T], service_name: str | None = None) -> T | None:
        """Resolve ``resolve_type``, optionally by ``service_name``.

        Collection types (``list[X]``, ``tuple[X, ...]``, ``set[X]``) return
        every unnamed registration of ``X``, and ``service_name`` is ignored.
        Any other type returns a single instance or ``None``.
        """
        type_origin = get_origin(resolve_type) or resolve_type
        if type_origin in self._COLLECTION_TYPES:
            return self._resolve_services(resolve_type, type_origin)
        return self._resolve_service(resolve_type, service_name)
def default_factory(service_type: type[T]) -> Callable[[IServiceProvider], T]:
    """Build a factory that constructs ``service_type`` from the provider.

    Each type-annotated ``__init__`` parameter is requested from the provider
    and passed by keyword. A dependency the provider cannot supply (``None``)
    is left out when the parameter declares a default, so that default
    applies instead of an explicit ``None``. Annotated ``*args`` /
    ``**kwargs`` are never filled. Unannotated parameters are not resolved.

    Args:
        service_type: The class to instantiate.

    Returns:
        A callable taking an ``IServiceProvider`` and returning a new instance.
    """
    def _(provider: IServiceProvider) -> T:
        init = service_type.__init__
        # Hints are read at resolve time, not registration time, so forward
        # references only need to exist once services are actually requested.
        type_hints = get_type_hints(init)
        try:
            parameters = inspect.signature(init).parameters
        except (TypeError, ValueError):
            # Some builtins expose no signature; fall back to the hints alone.
            parameters = {}
        kwargs = {}
        for hint_name, hint_type in type_hints.items():
            if hint_name == "return":
                continue
            parameter = parameters.get(hint_name)
            if parameter is not None and parameter.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD):
                # *args / **kwargs cannot be filled by a keyword of that name.
                continue
            service = provider.get_service(hint_type)
            if (service is None and parameter is not None
                    and parameter.default is not inspect.Parameter.empty):
                # Unregistered dependency: let the constructor's default apply.
                continue
            kwargs[hint_name] = service
        return service_type(**kwargs)
    return _
class Transient(Generic[T]):
    """Factory wrapper that calls the wrapped factory on every request."""

    # Callable that receives the provider and returns a fresh instance.
    _factory: Callable[[IServiceProvider], T]

    def __init__(self,
                 factory: Callable[[IServiceProvider], T] = None) -> None:
        self._factory = factory

    def get_instance(self, services: IServiceProvider) -> T:
        """Create and return a new instance using ``services``."""
        make = self._factory
        return make(services)
class Singleton(Generic[T]):
    """Factory wrapper that calls the wrapped factory once and reuses the result."""

    # The cached instance; only meaningful once _created is True.
    _instance: T | None
    # Whether the factory has already run successfully. Kept separate from
    # _instance so a factory that returns None is still only called once.
    _created: bool
    _factory: Callable[[IServiceProvider], T]

    def __init__(self,
                 factory: Callable[[IServiceProvider], T] = None) -> None:
        self._instance = None
        self._created = False
        self._factory = factory

    def get_instance(self, services: IServiceProvider) -> T:
        """Return the cached instance, creating it with ``services`` on first use.

        If the factory raises, nothing is cached and the next call retries.
        """
        if not self._created:
            self._instance = self._factory(services)
            self._created = True
        return self._instance
class IServiceContainer(Protocol):
    """Structural interface for registering services and building a provider.

    Each ``add_*`` method takes ``resolve_type`` (the key callers ask for)
    plus either a concrete ``service_type`` or a ``service_factory`` that
    receives the provider, optionally under a ``service_name``.

    NOTE(review): the positional forms of the ``service_factory`` overloads
    do not line up with the implementation signature, whose second
    positional parameter is ``service_type``. Prefer passing
    ``service_name`` / ``service_factory`` by keyword.
    """

    @overload
    def add_singleton(self, resolve_type: type[T],
                      service_type: type[T]) -> None: ...

    @overload
    def add_singleton(self, resolve_type: type[T],
                      service_type: type[T],
                      service_name: str) -> None: ...

    @overload
    def add_singleton(self, resolve_type: type[T],
                      service_factory: Callable[[IServiceProvider], T]) -> None: ...

    @overload
    def add_singleton(self, resolve_type: type[T],
                      service_name: str,
                      service_factory: Callable[[IServiceProvider], T]) -> None: ...

    def add_singleton(self, resolve_type: type[T],
                      service_type: type[T] = None,
                      service_name: str = None,
                      service_factory: Callable[[IServiceProvider], T] = None) -> None:
        """Register ``resolve_type`` with singleton lifetime."""

    @overload
    def add_transient(self, resolve_type: type[T],
                      service_type: type[T]) -> None: ...

    @overload
    def add_transient(self, resolve_type: type[T],
                      service_type: type[T],
                      service_name: str) -> None: ...

    @overload
    def add_transient(self, resolve_type: type[T],
                      service_factory: Callable[[IServiceProvider], T]) -> None: ...

    @overload
    def add_transient(self, resolve_type: type[T],
                      service_name: str,
                      service_factory: Callable[[IServiceProvider], T]) -> None: ...

    def add_transient(self, resolve_type: type[T],
                      service_type: type[T] = None,
                      service_name: str = None,
                      service_factory: Callable[[IServiceProvider], T] = None) -> None:
        """Register ``resolve_type`` with transient lifetime."""

    def build(self) -> IServiceProvider:
        """Create the ``IServiceProvider`` that resolves the registered services."""
class ServiceContainer:
    """Collect service registrations and build a ``ServiceProvider`` from them.

    Registrations are stored in ``_services`` keyed by
    ``(resolve_type, service_name)``. Unnamed registrations of the same type
    accumulate in registration order. A named registration replaces any
    earlier one under the same name.
    """

    _services: dict[tuple[type, str | None], list[IServiceFactory]]

    def __init__(self) -> None:
        self._services = {}

    @overload
    def add_singleton(self,
                      resolve_type: type[T],
                      service_type: type[T]) -> None: ...

    @overload
    def add_singleton(self,
                      resolve_type: type[T],
                      service_type: type[T],
                      service_name: str) -> None: ...

    @overload
    def add_singleton(self,
                      resolve_type: type[T],
                      service_factory: Callable[[IServiceProvider], T]) -> None: ...

    @overload
    def add_singleton(self,
                      resolve_type: type[T],
                      service_name: str,
                      service_factory: Callable[[IServiceProvider], T]) -> None: ...

    def add_singleton(self,
                      resolve_type: type[T],
                      service_type: type[T] = None,
                      service_name: str = None,
                      service_factory: Callable[[IServiceProvider], T] = None) -> None:
        """Register ``resolve_type`` behind a ``Singleton`` factory wrapper.

        Args:
            resolve_type: Key that callers pass to ``get_service``.
            service_type: Class built via ``default_factory``; defaults to
                ``resolve_type``.
            service_name: Optional name; a named registration replaces an
                earlier one with the same name.
            service_factory: Callable receiving the provider; takes
                precedence over ``service_type``.
        """
        self._add(Singleton, resolve_type, service_type, service_name, service_factory)

    @overload
    def add_transient(self,
                      resolve_type: type[T],
                      service_type: type[T]) -> None: ...

    @overload
    def add_transient(self,
                      resolve_type: type[T],
                      service_type: type[T],
                      service_name: str) -> None: ...

    @overload
    def add_transient(self,
                      resolve_type: type[T],
                      service_factory: Callable[[IServiceProvider], T]) -> None: ...

    @overload
    def add_transient(self,
                      resolve_type: type[T],
                      service_name: str,
                      service_factory: Callable[[IServiceProvider], T]) -> None: ...

    def add_transient(self,
                      resolve_type: type[T],
                      service_type: type[T] = None,
                      service_name: str = None,
                      service_factory: Callable[[IServiceProvider], T] = None) -> None:
        """Register ``resolve_type`` behind a ``Transient`` factory wrapper.

        Arguments are interpreted exactly as for ``add_singleton``.
        """
        self._add(Transient, resolve_type, service_type, service_name, service_factory)

    def _add(self,
             lifetime: Callable[[Callable[[IServiceProvider], T]], IServiceFactory],
             resolve_type: type[T],
             service_type: type[T] | None,
             service_name: str | None,
             service_factory: Callable[[IServiceProvider], T] | None) -> None:
        """Normalize overload arguments, wrap the factory in ``lifetime`` and register it."""
        service_type, service_name, service_factory = self._normalize_arguments(
            service_type, service_name, service_factory)
        if service_factory is None:
            service_factory = default_factory(service_type or resolve_type)
        self._register(resolve_type, service_name, lifetime(service_factory))

    @staticmethod
    def _is_factory(candidate: object) -> bool:
        """Return True if ``candidate`` is a function that accepts one positional argument.

        Zero-argument callables are excluded so they keep their previous
        treatment as a ``service_type`` built via ``default_factory``.
        """
        if not inspect.isroutine(candidate):
            return False
        try:
            inspect.signature(candidate).bind(None)
        except (TypeError, ValueError):
            return False
        return True

    @classmethod
    def _normalize_arguments(cls,
                             service_type: object,
                             service_name: object,
                             service_factory: object) -> tuple:
        """Undo the argument shift caused by the positional factory overloads.

        The overloads advertise ``(resolve_type, service_factory)`` and
        ``(resolve_type, service_name, service_factory)``. The implementation's
        positional order is ``service_type, service_name, service_factory``,
        so such calls arrive shifted by one slot. Calls that already bind
        correctly are returned unchanged.

        Returns:
            ``(service_type, service_name, service_factory)``.

        Raises:
            TypeError: a factory was supplied both positionally and by keyword.
        """
        if isinstance(service_type, str):
            # (resolve_type, name, factory) positionally: the name landed in
            # service_type and the factory in service_name.
            if service_name is not None and service_factory is not None:
                raise TypeError(
                    "service_factory was given both positionally and by keyword")
            factory = service_factory if service_factory is not None else service_name
            return None, service_type, factory
        if service_factory is None and cls._is_factory(service_type):
            # (resolve_type, factory) positionally: the factory landed in service_type.
            return None, service_name, service_type
        return service_type, service_name, service_factory

    def _register(self,
                  resolve_type: type[T],
                  service_name: str | None,
                  factory: IServiceFactory) -> None:
        """Store ``factory`` under ``(resolve_type, service_name)``.

        Named keys hold only the latest factory; unnamed keys accumulate.
        """
        key = (resolve_type, service_name)
        if service_name is not None or key not in self._services:
            self._services[key] = [factory]
        else:
            self._services[key].append(factory)

    def build(self) -> ServiceProvider:
        """Create a provider over a snapshot of the current registrations."""
        # Copy the dict and each factory list: the provider adds its own
        # entries and must not write them back here, and later registrations
        # should not leak into an already-built provider. The factory objects
        # themselves are shared, so Singleton caches are shared across every
        # provider built from this container.
        return ServiceProvider(
            {key: list(factories) for key, factories in self._services.items()})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.