Last active
October 11, 2024 16:37
-
-
Save johnhungerford/ccb398b666fd72e69f6798921383cb3f to your computer and use it in GitHub Desktop.
Dependency injection for Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This file contains a proof-of-concept (POC) dependency-injection framework based on Scala's ZIO library.
The idea is that each dependency can have one or more "layers," each of which specifies how to construct
that type of dependency from other dependencies. These layers can be added to an "environment," which can
construct a value for any type that you query, so long as layers have been added for it and its upstream
dependencies.
Toward the end of the file there's a simple example demonstrating how this framework can be used for | |
configuring an application. | |
""" | |
import abc
import inspect
import sys
from contextlib import ExitStack, contextmanager, nullcontext
from dataclasses import dataclass, fields, is_dataclass
from typing import (
    Any,
    Callable,
    ClassVar,
    ContextManager,
    Dict,
    ForwardRef,
    Generic,
    Iterable,
    Iterator,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Self,
    TextIO,
    Tuple,
    Type,
    TypeVar,
    get_args,
    get_origin,
)
######################################## | |
## ## | |
## DEPENDENCY INJECTION FRAMEWORK ## | |
## ## | |
######################################## | |
# Generic type parameter used throughout the framework's signatures (Env lookups, Layer[A]).
A = TypeVar('A')
class Injectable(abc.ABC): | |
""" | |
If a dataclass inherits this, it can be added to an Env instance as a dependency. Any | |
Injectable subclasses added to an environment will be accessible as the highest supertype | |
that inherits Injectable. The idea is that you can choose which implementation of an | |
interface (i.e., base class) to provide to your environment, and it will be accessible as | |
that interface. | |
""" | |
injected_class: ClassVar[Type['Injectable']] | |
@classmethod | |
def construct(cls, env: 'Env') -> ContextManager[Self]: | |
if is_dataclass(cls): | |
@contextmanager | |
def ctx(): | |
with env.multi_context(*(f.type for f in fields(cls))) as params: | |
yield cls(*params) | |
return ctx() | |
raise NotImplementedError( | |
(f'Cannot generate an injectable constructor for {cls.__name__} because it is not a dataclass. Your options are:' | |
f'\n 1. Make {cls.__name__} a dataclass by adding @dataclasses.dataclass above its definition' | |
f'\n 2. Add an implementation of the classmethod "construct" on {cls.__name__}' | |
f'\n 3. If {cls.__name__} is a base/abstract class, do not add it to an environment. Add an implementation instead.') | |
) | |
@classmethod | |
def __get_highest_class(cls) -> Type['Injectable']: | |
if hasattr(cls, 'injected_class') and cls.injected_class is not None: | |
return cls.injected_class | |
found_injectable: Optional[Type[Injectable]] = None | |
for typ in cls.__bases__: | |
if typ is Injectable: | |
return cls | |
if inspect.isclass(typ) and issubclass(typ, Injectable): | |
found_injectable = typ | |
if found_injectable is not None: | |
return found_injectable.__get_highest_class() | |
raise ValueError(f'Could not find highest subclass of Injectable from {cls.__name__}') | |
@classmethod | |
def layer(cls) -> 'Layer[Self]': | |
return Layer(cls.__get_highest_class(), cls, cls.construct) | |
@dataclass(frozen=True) | |
class _EnvValue: | |
"""Wrapper around a value to be used by Env""" | |
value: Any | |
@dataclass(frozen=True) | |
class _EnvLayer: | |
"""Wrapper around a layer to be used by Env""" | |
specific_typ: Type[Any] | |
function: Callable[['Env'], Any] | |
class Env:
    """
    A registry of dependencies that can construct any registered type on request.

    Each type maps either to a ready-made value (``add_value``) or to a layer (``add_layer``, or
    an ``Injectable`` class passed to the constructor) that builds it from other dependencies in
    this same environment. While a layer-built value is in use (its context is open) it is cached,
    so dependents resolved during that time share the same instance; when the context exits, the
    layer is restored and the next request builds a fresh value.

    Prefer ``context`` / ``multi_context`` for scoped access. Values fetched with ``env[SomeType]``
    stay open (and cached) until ``close()`` is called or a ``with Env(...) as env:`` block exits.
    """

    # Registered dependencies: interface type -> ready value or layer that can build it.
    __env_map: MutableMapping[Type[Any], _EnvValue | _EnvLayer]
    # Chain of types currently being resolved (outermost first), each paired with the
    # implementation chosen for it (None until a layer is selected). Used for error messages
    # and cycle detection.
    __get_stack: List[Tuple[Type[Any], Optional[Type[Any]]]]
    # Keeps contexts opened by __getitem__ alive until close().
    __exit_stack: ExitStack
    # Sentinel yielded by __get when nothing provides a type, so None can be a legitimate value.
    __MISSING: ClassVar[object] = object()

    def __init__(self, *layers: 'Layer | Type[Injectable]'):
        """
        Create an environment from layers and/or Injectable classes.

        :param layers: ``Layer`` instances, or ``Injectable`` subclasses (registered via ``cls.layer()``).
        :raises TypeError: if an argument is neither a ``Layer`` nor an ``Injectable`` subclass.
        """
        self.__env_map = {}
        self.__get_stack = []
        self.__exit_stack = ExitStack()
        for layer in layers:
            if isinstance(layer, Layer):
                self.add_layer(layer)
            elif inspect.isclass(layer) and issubclass(layer, Injectable):
                self.add_layer(layer.layer())
            else:
                # Previously non-Injectable classes were silently ignored, hiding misconfiguration.
                raise TypeError(f'Expected a Layer or an Injectable subclass, got {layer!r}')

    def add_value(self, typ: Type[A], value: A) -> 'Env':
        """Register a ready-made ``value`` for ``typ``, replacing any earlier entry; returns self."""
        self.__env_map[typ] = _EnvValue(value)
        return self

    def add_layer(self, layer: 'Layer') -> 'Env':
        """Register ``layer`` under its interface type, replacing any earlier entry; returns self."""
        self.__env_map[layer.interface] = _EnvLayer(layer.implementation, layer.fn)
        return self

    @staticmethod
    def __type_name(typ: Any) -> str:
        """Readable name for a requested type; falls back to repr for objects without __name__."""
        return getattr(typ, '__name__', repr(typ))

    @contextmanager
    def __get(self, typ: Type[A]) -> Iterator[Any]:
        """
        Resolve ``typ`` as a context, yielding its value or ``__MISSING`` if nothing provides it.

        On a miss the request is deliberately left on ``__get_stack`` so the caller can build an
        error message describing the whole dependency chain. Only this call's own stack entries
        are removed on exit, so nested contexts opened and closed inside a layer function do not
        disturb the resolution that is still in progress around them.

        :raises ValueError: if ``typ`` is already being resolved (a dependency cycle).
        """
        # Types on the stack are mid-construction: requesting one again can never complete.
        if any(stack_typ is typ for stack_typ, _ in self.__get_stack):
            chain = ' -> '.join(self.__type_name(stack_typ) for stack_typ, _ in self.__get_stack)
            raise ValueError(f'Circular dependency detected: {chain} -> {self.__type_name(typ)}')
        depth = len(self.__get_stack)
        self.__get_stack.append((typ, None))
        try:
            entry = self.__env_map.get(typ)
            if isinstance(entry, _EnvValue):
                del self.__get_stack[depth:]
                yield entry.value
            elif isinstance(entry, _EnvLayer):
                # Record which implementation is being built so error messages can show it.
                self.__get_stack[depth] = (typ, entry.specific_typ)
                with entry.function(self) as value:
                    # Cache the built value so dependents resolved while it is in use share it.
                    self.__env_map[typ] = _EnvValue(value)
                    del self.__get_stack[depth:]
                    try:
                        yield value
                    finally:
                        # Restore the layer so the next resolution builds a fresh value.
                        self.__env_map[typ] = entry
            else:
                yield self.__MISSING
        finally:
            del self.__get_stack[depth:]

    def __error_message(self) -> str:
        """
        Describe why the most recent request failed, using the current resolution chain.

        The first stack entry is the type originally requested; the last is the type for which
        no value or layer was registered. Where a layer supplied a subtype, it is listed too.
        """
        if not self.__get_stack:
            return 'Unable to construct requested type from environment'
        root_typ = self.__get_stack[0][0]
        lines = [f'Unable to construct {self.__type_name(root_typ)} from environment']
        last_value: Optional[str] = None
        last_index = len(self.__get_stack) - 1
        for i, (stack_typ, spec_stack_typ) in enumerate(self.__get_stack):
            name = self.__type_name(stack_typ)
            if i == last_index:
                line = f'{name} <- not provided!'
                if last_value is not None:
                    line += f' (required by {last_value})'
                lines.append(line)
            elif i > 0:
                lines.append(f'{name} <- required by {last_value}')
            last_value = name
            if spec_stack_typ is not None and spec_stack_typ is not stack_typ:
                spec_name = self.__type_name(spec_stack_typ)
                lines.append(f'{spec_name} <- provided subtype of {name}')
                last_value = spec_name
        return '\n'.join(lines)

    def __getitem__(self, typ: Type[A]) -> A:
        """
        Retrieve a dependency of type ``typ`` without a ``with`` block.

        The underlying context is kept open, so resources such as files stay usable and the
        value stays cached, until ``close()`` is called.

        :raises ValueError: if ``typ`` cannot be constructed from this environment.
        """
        return self.__exit_stack.enter_context(self.context(typ))

    @contextmanager
    def context(self, typ: Type[A]) -> Iterator[A]:
        """
        Retrieve a dependency of type A as a context.

        E.g.:
            with my_env.context(SomeService) as some_service:
                some_service.do_something()

        :raises ValueError: if ``typ`` (or anything it transitively depends on) has no registered
            value or layer, or if its dependencies form a cycle.
        """
        with self.__get(typ) as value:
            if value is self.__MISSING:
                raise ValueError(self.__error_message())
            yield value

    @contextmanager
    def multi_context(self, *types: Type[Any]) -> Iterator[Tuple[Any, ...]]:
        """
        Retrieve multiple dependencies in tuple form (in the order requested) as a context.

        E.g.:
            with my_env.multi_context(ServiceA, ServiceB) as (service_a, service_b):
                a_val = service_a.calculate_something()
                service_b.do_something(a_val)
        """
        # ExitStack closes every dependency entered so far if a later one fails to resolve.
        with ExitStack() as stack:
            values = tuple(stack.enter_context(self.context(t)) for t in types)
            yield values

    def close(self) -> None:
        """Exit every context opened via ``env[...]``, most recent first, releasing their resources."""
        self.__exit_stack.close()

    def __enter__(self) -> 'Env':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()
@dataclass(frozen=True)
class Layer(Generic[A]):
    """
    A dependency constructor.

    It records both the type ("interface") the dependency is registered as in an environment and
    the concrete type ("implementation") it produces. ``fn`` pulls any upstream dependencies from
    an environment and returns a context manager (resource) that yields the dependency.
    """
    # The type this layer's dependency is registered (and looked up) as.
    interface: Type[A]
    # The concrete type produced; reported in dependency error messages.
    implementation: Type[A]
    # Builds the dependency from an Env; must return a context manager yielding it.
    fn: Callable[[Env], ContextManager[A]]

    @classmethod
    def simple(cls, typ: Type[A], fn: Callable[[Env], ContextManager[A]]) -> 'Layer[A]':
        """
        Generate a simple layer where there's no distinction between the interface and
        implementation.
        """
        return cls(typ, typ, fn)

    @classmethod
    def impl(
        cls,
        interface: Type[A],
        implementation: Type[A],
        fn: Callable[[Env], ContextManager[A]],
    ) -> 'Layer[A]':
        """
        Generate a layer that provides ``implementation`` as ``interface``.

        Equivalent to ``Layer(interface, implementation, fn)``; named for readability at call sites.
        """
        return cls(interface, implementation, fn)
######################################## | |
## ## | |
## EXAMPLE DEPENDENCY DEFINITIONS ## | |
## ## | |
######################################## | |
# INTERFACE | |
class HelloService(Injectable, abc.ABC):
    """Interface for a service that produces a greeting."""

    @abc.abstractmethod
    def say_hello(self) -> str:
        """Return a greeting string."""
# IMPLEMENTATION | |
@dataclass
class HelloServiceLive(HelloService):
    """The "live" HelloService implementation."""

    def say_hello(self) -> str:
        message = "hello (live)"
        return message
# IMPLEMENTATION | |
@dataclass
class HelloServiceTest(HelloService):
    """The "test" HelloService implementation."""

    def say_hello(self) -> str:
        message = "hello (test)"
        return message
# INTERFACE | |
class GoodbyeService(Injectable, abc.ABC):
    """Interface for a service that produces a farewell."""

    @abc.abstractmethod
    def say_goodbye(self) -> str:
        """Return a farewell string."""
# IMPLEMENTATION | |
@dataclass
class GoodbyeServiceLive(GoodbyeService):
    """The "live" GoodbyeService implementation."""

    def say_goodbye(self) -> str:
        message = "goodbye (live)"
        return message
# IMPLEMENTATION | |
@dataclass
class GoodbyeServiceTest(GoodbyeService):
    """The "test" GoodbyeService implementation."""

    def say_goodbye(self) -> str:
        message = "goodbye (test)"
        return message
# INTERFACE | |
class ConversationService(Injectable, abc.ABC):
    """Interface for a service that produces a whole conversation as text."""

    @abc.abstractmethod
    def converse(self) -> str:
        """Return the conversation text."""
# IMPLEMENTATION | |
@dataclass
class ConversationServiceLive(ConversationService):
    """ConversationService built from a greeting line followed by a farewell line."""
    hello_service: HelloService
    goodbye_service: GoodbyeService

    def converse(self) -> str:
        greeting_line = self.hello_service.say_hello() + '\n'
        farewell_line = self.goodbye_service.say_goodbye() + '\n'
        return greeting_line + farewell_line
# IMPLEMENTATION | |
@dataclass
class ConversationServiceSimple(ConversationService):
    """
    ConversationService that returns a fixed, pre-built conversation string.

    Since this class has a str dependency, we probably don't want that to be autowired
    (what if multiple classes have a str dependency? The same string would be passed to
    all of them!). So instead, we provide a custom layer that simply constructs a value. In
    the real world, this would probably be constructed from a configuration source, which
    could have its own layer to pull from a file or the environment.
    """
    conversation: str

    def converse(self) -> str:
        return self.conversation

    @classmethod
    def converse_layer(cls, conversation: str) -> Layer[ConversationService]:
        """
        Build a layer providing this class as the ``ConversationService`` implementation.

        Provide the returned layer to an environment instead of the class itself.

        :param conversation: conversation text; a trailing newline is appended.
        """
        # Env enters a layer function's result as a context manager, so the plain instance is
        # wrapped in nullcontext, which yields it unchanged and has nothing to release.
        return Layer(ConversationService, cls, lambda _: nullcontext(cls(conversation + '\n')))
# INTERFACE | |
class IOService(Injectable, abc.ABC):
    """Interface for writing the application's text output somewhere."""

    @abc.abstractmethod
    def write_str(self, value: str):
        """Write ``value`` to this service's output."""
# IMPLEMENTATION | |
@dataclass
class StdIOService(IOService):
    """IOService implementation that writes to the process's standard output."""

    def write_str(self, value: str):
        stdout = sys.stdout
        stdout.write(value)
# IMPLEMENTATION | |
@dataclass
class FileIOService(IOService):
    """
    IOService implementation that writes to an open text file.

    Like ConversationServiceSimple, this implementation depends on a file object
    that we wouldn't want to autowire. This case is more complicated, however, because
    a file object is a resource, not an immutable str.
    In this case, we add a layer to provide a file as a resource.
    """
    file: TextIO

    def write_str(self, value: str):
        """Write ``value`` to the underlying file."""
        self.file.write(value)

    @classmethod
    def resource_layer(cls, file_path: str, mode: str = 'w', encoding: str = 'utf-8') -> Layer[IOService]:
        """
        Build a layer that keeps ``file_path`` open for as long as the dependency is in use.

        :param file_path: path of the file to write to.
        :param mode: text-mode ``open`` mode, e.g. ``'w'`` (truncate, the default) or ``'a'`` (append).
        :param encoding: text encoding; explicit so output does not depend on the platform locale.
        :return: a layer registering ``cls`` as the ``IOService`` implementation.
        """
        @contextmanager
        def constr(_: Env) -> Iterator[IOService]:
            # The file is closed when the environment exits this dependency's context.
            with open(file_path, mode, encoding=encoding) as file:
                yield cls(file)

        return Layer(IOService, cls, constr)
# APPLICATION (no distinction between interface and implementation) | |
@dataclass
class ConversationApplication(Injectable):
    """Top-level application: writes the conversation to the configured output."""
    conversation_service: ConversationService
    io_service: IOService

    def run(self):
        """Produce the conversation text and hand it to the IO service."""
        text = self.conversation_service.converse()
        writer = self.io_service
        writer.write_str(text)
######################################## | |
## ## | |
## EXAMPLE APPLICATION USING DI ## | |
## ## | |
######################################## | |
# Define the environment by adding Injectable classes or manually built layers.
# Comment/uncomment implementations to change how the application is constructed.
# Keep one implementation active per interface: a later registration for the same
# interface replaces an earlier one.
my_env = Env(
    ConversationApplication,
    StdIOService,
    # FileIOService.resource_layer('output.txt'),
    ConversationServiceLive,
    # ConversationServiceSimple.converse_layer('this is the full conversation'),
    HelloServiceLive,
    # HelloServiceTest,
    GoodbyeServiceLive,
    # GoodbyeServiceTest,
)

# Resolve the application (and, transitively, everything it depends on) and run it.
# Leaving the `with` block exits every layer's context, releasing any resources acquired.
with my_env.context(ConversationApplication) as app:
    app.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment