Last active
July 28, 2023 22:27
-
-
Save shawn42/5a6569f10500aafa2d68ff2848f9306b to your computer and use it in GitHub Desktop.
Simple DI in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect
from functools import lru_cache, wraps
from typing import Any, Callable, TypeVar, Type, cast, get_type_hints
T = TypeVar('T')


class AppContext:
    """A minimal dependency-injection container keyed by type.

    Each registered type maps to a provider (the class itself or a factory
    function) plus a ``cache`` flag. Cached dependencies are built once and
    reused; uncached ones are rebuilt by every :meth:`resolve`.
    ``context[SomeType]`` resolves a dependency and ``context[SomeType] = obj``
    binds a fixed instance.
    """

    def __init__(self) -> None:
        # type -> {"provider": zero-argument callable, "cache": bool}
        self.dependencies: dict[type, dict[Any, Any]] = {}
        # type -> instance; only populated for dependencies with cache=True
        self.cache: dict[type, Any] = {}

    def is_registered(self, dependency_type: Type[T]) -> bool:
        """Return True if a provider is registered for *dependency_type*."""
        return dependency_type in self.dependencies

    def register(self, dependency: Any, **options: Any) -> None:
        """Register a class, or a factory keyed by its return annotation.

        Args:
            dependency: A class (registered under itself and used as its own
                provider) or a callable annotated ``-> SomeType`` (registered
                under ``SomeType``).
            **options: ``cache`` (default True) reuses a single instance per
                type. Other keys are ignored.

        Raises:
            TypeError: if a non-class *dependency* has no return annotation.

        Re-registering a type prints a warning, replaces its provider and
        discards any instance cached from the previous provider.
        """
        cache = options.pop('cache', True)
        key = self._registration_key(dependency)
        if key in self.dependencies:
            print(f"WARN: Dependency already registered for type: {dependency}")
        # A cached instance from the previous provider must not outlive it.
        self.cache.pop(key, None)
        self.dependencies[key] = {"provider": dependency, "cache": cache}

    @staticmethod
    def _registration_key(dependency: Any) -> Any:
        """Return the type that *dependency* is registered under."""
        # isinstance (not `type(x) == type`) also accepts classes that use a
        # custom metaclass, e.g. subclasses of abc.ABC.
        if isinstance(dependency, type):
            return dependency
        # get_type_hints also evaluates string annotations (forward references,
        # `from __future__ import annotations`), unlike raw __annotations__.
        return_type = get_type_hints(dependency).get('return')
        if return_type is None:
            raise TypeError(
                f"Cannot register {dependency!r}: expected a class or a "
                f"factory with a return type annotation"
            )
        return return_type

    def resolve(self, dependency_type: Type[T]) -> T:
        """Return an instance of *dependency_type*, building it if needed.

        Raises:
            ValueError: if no provider is registered for *dependency_type*.
        """
        entry = self.dependencies.get(dependency_type)
        if entry is None:
            raise ValueError(f"Dependency not registered for type: {dependency_type}")
        use_cached = entry["cache"]
        if use_cached and dependency_type in self.cache:
            return cast(T, self.cache[dependency_type])
        instance = cast(T, entry["provider"]())
        if use_cached:
            self.cache[dependency_type] = instance
        return instance

    def set(self, key: Type[T], instance: T) -> None:
        """Bind *key* to a fixed *instance* that later resolves return.

        Caching is forced on for *key*, so the provider is not consulted again.
        An unregistered *key* becomes registered with *instance* as its value.
        """
        entry = self.dependencies.get(key)
        if entry is None:
            self.dependencies[key] = {"provider": lambda: instance, "cache": True}
        else:
            entry["cache"] = True
        self.cache[key] = instance

    def __setitem__(self, key: Type[T], instance: T) -> None:
        """``context[key] = instance`` is shorthand for :meth:`set`."""
        self.set(key, instance)

    def __getitem__(self, key: Type[T]) -> T:
        """``context[key]`` is shorthand for :meth:`resolve`."""
        return self.resolve(key)


def inject_dependencies_from(context: AppContext) -> Callable[[Type[T]], Type[T]]:
    """Build a class decorator that registers a class and auto-wires ``__init__``.

    The decorated class is registered with *context* (cached by default). Its
    ``__init__`` is wrapped so that each annotated parameter the caller did not
    supply is filled with ``context.resolve(annotation)``. Arguments passed
    explicitly, positionally or by keyword, take precedence over injection.
    A parameter that has a default and whose annotated type is not registered
    keeps its default. ``*args``/``**kwargs`` parameters are never injected.
    """

    def decorator(cls: Type[T]) -> Type[T]:
        context.register(cls)
        original_init = cls.__init__
        signature = inspect.signature(original_init)

        @lru_cache(maxsize=None)
        def parameter_hints() -> dict[str, Any]:
            # Evaluated on first instantiation rather than at decoration time,
            # so annotations may name classes defined later in the module.
            return get_type_hints(original_init)

        @wraps(original_init)
        def init_with_dependencies(self: Any, *args: Any, **kwargs: Any) -> None:
            bound = signature.bind_partial(self, *args, **kwargs)
            type_hints = parameter_hints()
            for name, param in signature.parameters.items():
                if name in bound.arguments or name not in type_hints:
                    continue  # supplied by the caller, or not annotated
                if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                    continue  # annotations here describe items, not one dependency
                dependency_type = type_hints[name]
                has_default = param.default is not param.empty
                if has_default and not context.is_registered(dependency_type):
                    continue  # let the declared default apply
                bound.arguments[name] = context.resolve(dependency_type)
            original_init(*bound.args, **bound.kwargs)

        cls.__init__ = init_with_dependencies  # type: ignore
        return cls

    return decorator
# Shared container that the decorated example classes below register with.
context = AppContext()


class Filesystem:
    """Leaf dependency for the example; its constructor takes no arguments."""
@inject_dependencies_from(context)
class Settings:
    """Example settings object whose ``fs`` argument is injected from ``context``."""

    def __init__(self, fs: Filesystem) -> None:
        message = f"settings: {fs}"
        print(message)

    @property
    def db_name(self) -> str:
        """Name of the example database (constant)."""
        name = "example_db"
        return name
# Example usage: a second injected class that depends on Settings.
@inject_dependencies_from(context)
class Database:
    """Example service that copies ``db_name`` from the injected Settings."""

    def __init__(self, settings: Settings) -> None:
        name_from_settings = settings.db_name
        self.db_name = name_from_settings
def use_db(db: Database) -> None:
    """Print ``db.db_name`` ("example_db" for the example wiring)."""
    name = db.db_name
    print(name)
class Unregistered:
    """A type the example never registers, used to show a failed resolve."""
if __name__ == "__main__":
    # A factory function is registered under its return annotation...
    def create_fs() -> Filesystem:
        return Filesystem()

    context.register(create_fs)
    # ...or the class can be registered directly. This replaces the factory
    # registered above, so register() prints a "WARN: ... already registered".
    context.register(Filesystem)
    # Settings needs no explicit registration: @inject_dependencies_from did it.

    # Re-registering Database replaces the decorator's default (cached)
    # registration; this also prints a WARN.
    # TODO: offer an explicit way to change options without the warning.
    context.register(Database, cache=False)

    # Resolve dependencies; mypy sees db_instance as a Database.
    db_instance = context[Database]
    use_db(db_instance)  # prints "example_db"

    # Resolving a type that was never registered raises ValueError.
    try:
        context[Unregistered]
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError for an unregistered type")

    # Cached dependencies are shared; uncached ones are rebuilt per resolve.
    assert context[Filesystem] is context[Filesystem]
    assert context[Database] is not context[Database]

    class FakeDatabase(Database):
        def __init__(self) -> None:
            ...  # deliberately skips Database.__init__ (no Settings needed)

    fake_db = FakeDatabase()
    context[Database] = fake_db  # binding an instance also turns caching on
    assert context[Database] is fake_db
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.