Skip to content

Instantly share code, notes, and snippets.

@NixonInnes
Created October 21, 2024 23:31
Show Gist options
  • Save NixonInnes/6b7cf851cd2715aafbf32c4fb54405ac to your computer and use it in GitHub Desktop.
Save NixonInnes/6b7cf851cd2715aafbf32c4fb54405ac to your computer and use it in GitHub Desktop.
Python Services Container
from abc import ABC, abstractmethod
import logging
from enum import Enum
import inspect
import threading
from typing import Any, Callable, get_type_hints, override, Self
from functools import partial
from ssdi.exceptions import (
ServicesCircularDependencyError,
ServicesResolutionError,
ServicesRegistrationError,
)
class ServiceLife(Enum):
    """Lifetime policy attached to a registered service.

    SINGLETON: the container builds the service once, caches it, and returns
        that same object on every later request.
    TRANSIENT: the container builds a brand-new instance on every request.
    """

    # Integer values are spelled out so ``.value`` stays stable (1 and 2).
    SINGLETON = 1
    TRANSIENT = 2
class ServiceEntry[T]:
service_class: type[T]
service_life: ServiceLife
instance: T | None
def __init__(self, service_class: type[T], service_life: ServiceLife):
self.service_class = service_class
self.service_life = service_life
self.instance = None
self.lock = threading.Lock()
class IServiceContainer(ABC):
@abstractmethod
def register[T](
self,
abstract_class: type[T],
service_class: type[T],
service_life: ServiceLife = ServiceLife.TRANSIENT,
) -> None: ...
@abstractmethod
def get[T](self, abstract_class: type[T], **overrides: Any) -> T: ... # pyright: ignore[reportAny]
def __getitem__[T](self, key: type[T]) -> Callable[..., T]:
return partial(self.get, key)
class ServiceContainer(IServiceContainer):
"""
A container for registering and resolving services with dependency injection.
"""
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
self.__lock = threading.Lock()
self.__services: dict[type[Any], ServiceEntry[Any]] = {}
self.__resolution_stack = threading.local()
this_service_entry = ServiceEntry[Self](
self.__class__,
ServiceLife.SINGLETON,
)
this_service_entry.instance = self
self.__services[IServiceContainer] = this_service_entry
self.logger.debug("Initialised service container")
def _set_service[T](self, key: type[T], value: ServiceEntry[T]) -> None:
with self.__lock:
self.__services[key] = value
def _get_service[T](self, key: type[T]) -> ServiceEntry[T]:
with self.__lock:
return self.__services[key]
@property
def stack(self):
if not hasattr(self.__resolution_stack, 'stack'):
self.__resolution_stack.stack = set()
self.logger.debug("Initialized a new resolution stack for the thread.")
return self.__resolution_stack.stack
@override
def register[T](
self,
abstract_class: type[T],
service_class: type[T],
service_life: ServiceLife = ServiceLife.TRANSIENT,
):
"""
Registers a service class with the container.
:param abstract_class: The abstract class or interface.
:param service class: The concrete implementation class.
:param service_life: The lifecycle of the service.
"""
if abstract_class in self.__services:
raise ServicesRegistrationError(
f"Service '{abstract_class.__name__}' is already registered."
)
assert issubclass(
service_class, abstract_class
), f"'{service_class.__name__}' does not implement '{abstract_class.__name__}'"
service = ServiceEntry[T](service_class, service_life)
self._set_service(abstract_class, service)
@override
def get[T](self, abstract_class: type[T], **overrides: Any) -> T: # pyright: ignore[reportAny]
if abstract_class in self.stack:
raise ServicesCircularDependencyError(
f"Circular dependency detected for '{abstract_class.__name__}'"
)
self.stack.add(abstract_class)
try:
service = self._get_service(abstract_class)
if service is None:
raise ServicesResolutionError(
f"No service of type '{abstract_class}' has been registered."
)
with service.lock:
if (
service.service_life == ServiceLife.SINGLETON
and service.instance is not None
):
if overrides:
raise ServicesResolutionError(
f"Service '{abstract_class.__name__}' is a singleton and already instanciated and overrides specified!"
)
return service.instance
instance = self._create_instance(service, overrides)
if service.service_life == ServiceLife.SINGLETON:
service.instance = instance
return instance
finally:
self.__resolution_stack.stack.remove(abstract_class)
def is_registered(self, abstract_class: type) -> bool:
return abstract_class in self.__services
def _create_instance[T](self, service_entry: ServiceEntry[T], overrides: dict[str, Any]) -> T:
signature = inspect.signature(service_entry.service_class.__init__)
type_hints = get_type_hints(service_entry.service_class.__init__)
kwargs = {}
for name, param in signature.parameters.items():
# Skip 'self' parameter
if name == "self":
continue
# Skip *args and **kwargs
if param.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
continue
if name in overrides:
kwargs[name] = overrides[name]
else:
param_type = type_hints.get(name)
if param_type is None:
if param.default != inspect.Parameter.empty:
kwargs[name] = param.default
continue
raise ServicesResolutionError(
f"No type annotation for parameter '{name}' in '{service_entry.service_class.__name__}'"
)
if param_type in self.__services:
kwargs[name] = self.get(param_type)
else:
if param.default != inspect.Parameter.empty:
kwargs[name] = param.default
else:
raise ServicesResolutionError(
f"Can't resolve dependency '{name}' of type '{param_type}' for service '{service_entry.service_class.__name__}'"
)
return service_entry.service_class(**kwargs)
@override
def __getitem__[T](self, key: type[T]) -> Callable[..., T]:
return partial(self.get, key)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment