Skip to content

Instantly share code, notes, and snippets.

@shawn42
Last active July 28, 2023 22:27
Show Gist options
  • Save shawn42/5a6569f10500aafa2d68ff2848f9306b to your computer and use it in GitHub Desktop.
Save shawn42/5a6569f10500aafa2d68ff2848f9306b to your computer and use it in GitHub Desktop.
Simple DI in Python
import inspect
from functools import wraps
from typing import Any, Callable, TypeVar, Type, cast, get_type_hints
# Generic type variable linking a requested type to the instance returned for it.
T = TypeVar("T")
class AppContext:
    """Minimal type-keyed dependency-injection container.

    Each registration maps a type to a zero-argument *provider* (a class or a
    factory function) plus a ``cache`` flag. ``resolve`` / ``context[SomeType]``
    calls the provider; when caching is enabled, the first instance produced is
    stored and returned for every later lookup.
    """

    def __init__(self) -> None:
        # type -> {"provider": zero-arg callable, "cache": bool}
        self.dependencies: dict[type, dict[Any, Any]] = {}
        # type -> instance produced (or installed via ``set``) for cached types
        self.cache: dict[type, Any] = {}

    def is_registered(self, dependency_type: Type[T]) -> bool:
        """Return True if a provider is registered for *dependency_type*."""
        return dependency_type in self.dependencies

    @staticmethod
    def _provided_type(dependency: Any) -> Any:
        """Return the type *dependency* provides: the class itself, or a
        callable's ``return`` annotation."""
        # isinstance (not ``type(x) == type``) so classes built with a custom
        # metaclass, e.g. abc.ABC subclasses, are also keyed by themselves.
        if isinstance(dependency, type):
            return dependency
        # get_type_hints also resolves string annotations (PEP 563 / forward refs).
        hints = get_type_hints(dependency)
        if 'return' not in hints:
            raise KeyError(f"Provider has no return annotation: {dependency}")
        return hints['return']

    def register(self, dependency: Any, **options: Any) -> None:
        """Register a class or factory function as the provider for its type.

        Args:
            dependency: A class (keyed by itself) or a zero-argument callable
                whose ``return`` annotation names the type it provides.
            **options: ``cache`` (default True) - reuse the first instance
                instead of calling the provider on every lookup. Other keys
                are ignored.

        Re-registering a type replaces its provider and options, prints a
        warning, and discards any instance cached from the previous provider.

        Raises:
            KeyError: if a non-class *dependency* has no return annotation.
        """
        cache = options.pop('cache', True)
        key = self._provided_type(dependency)
        if key in self.dependencies:
            print(f"WARN: Dependency already registered for type: {key}")
            # An instance built by the old provider must not outlive it.
            self.cache.pop(key, None)
        self.dependencies[key] = {"provider": dependency, "cache": cache}

    def resolve(self, dependency_type: Type[T]) -> T:
        """Return an instance of *dependency_type* from its registered provider.

        Raises:
            ValueError: if no provider is registered for *dependency_type*.
        """
        registration = self.dependencies.get(dependency_type)
        if registration is None:
            raise ValueError(f"Dependency not registered for type: {dependency_type}")
        use_cached = registration["cache"]
        if use_cached and dependency_type in self.cache:
            return cast(T, self.cache[dependency_type])
        instance = cast(T, registration["provider"]())
        if use_cached:
            self.cache[dependency_type] = instance
        return instance

    def set(self, key: Type[T], instance: T) -> None:
        """Install *instance* as the value returned for *key* (e.g. a test double).

        The registration for *key* is switched to cached so later lookups
        return *instance*. If *key* was never registered, a cached
        registration whose provider returns *instance* is created.
        """
        registration = self.dependencies.setdefault(
            key, {"provider": lambda: instance, "cache": True}
        )
        registration["cache"] = True
        self.cache[key] = instance

    def __setitem__(self, key: Type[T], instance: T) -> None:
        """``context[key] = instance`` is shorthand for ``set``."""
        self.set(key, instance)

    def __getitem__(self, key: Type[T]) -> T:
        """``context[key]`` is shorthand for ``resolve``."""
        return self.resolve(key)
def inject_dependencies_from(context: AppContext) -> Callable[[Type[T]], Type[T]]:
    """Return a class decorator that wires constructor arguments from *context*.

    The decorator registers the class in *context* (default options) and wraps
    its ``__init__``. Each annotated parameter the caller does not pass
    explicitly is filled with ``context.resolve(<annotation>)``. A parameter
    that has a default value and whose annotated type is not registered keeps
    its default.

    Args:
        context: Container used to register the class and to resolve its
            constructor dependencies.

    Returns:
        A decorator that returns the same class with a wrapped ``__init__``.
    """
    def decorator(cls: Type[T]) -> Type[T]:
        context.register(cls)
        original_init = cls.__init__
        # Evaluated once at decoration time, so annotations must resolve now.
        hints = get_type_hints(original_init)
        parameters = inspect.signature(original_init).parameters
        # Names after `self` that a caller may fill positionally, in order.
        positional_names = [
            name for name, param in parameters.items()
            if param.kind in (inspect.Parameter.POSITIONAL_ONLY,
                              inspect.Parameter.POSITIONAL_OR_KEYWORD)
        ][1:]
        # Only parameters passable by keyword can be injected. This also drops
        # the 'return' hint and any *args/**kwargs annotations.
        injectable = {
            name: hints[name]
            for name, param in parameters.items()
            if name in hints
            and param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                               inspect.Parameter.KEYWORD_ONLY)
        }

        @wraps(original_init)
        def init_with_dependencies(self, *args: Any, **kwargs: Any) -> None: #type: ignore
            # Explicitly passed arguments take precedence; only the remaining
            # injectable parameters are resolved from the context.
            supplied = set(positional_names[:len(args)]) | kwargs.keys()
            resolved = {}
            for name, dep_type in injectable.items():
                if name in supplied:
                    continue
                has_default = parameters[name].default is not inspect.Parameter.empty
                if has_default and not context.is_registered(dep_type):
                    continue
                resolved[name] = context.resolve(dep_type)
            original_init(self, *args, **resolved, **kwargs)

        cls.__init__ = init_with_dependencies # type: ignore
        return cls
    return decorator
# Module-level container; classes decorated with
# @inject_dependencies_from(context) register themselves into it.
context: AppContext = AppContext()
class Filesystem:
    """Placeholder dependency with no state or behavior of its own."""
@inject_dependencies_from(context)
class Settings:
    """Application settings; its Filesystem is injected from ``context``."""

    def __init__(self, fs: Filesystem) -> None:
        message = f"settings: {fs}"
        print(message)

    @property
    def db_name(self) -> str:
        """Name of the database to use (fixed in this example)."""
        name = "example_db"
        return name
# Example usage
@inject_dependencies_from(context)
class Database:
    """Database handle whose Settings dependency is injected from ``context``."""

    def __init__(self, settings: Settings) -> None:
        # Only the name is kept; the Settings instance itself is not stored.
        name = settings.db_name
        self.db_name = name
def use_db(db: Database) -> None:
    """Print the database name held by *db*."""
    name = db.db_name
    print(name)  # Output: "example_db"
class Unregistered:
    """Type never registered with ``context``; used to show the lookup error."""
if __name__ == "__main__":
    # Demo script. Registration order matters: a later register() call for the
    # same type replaces the earlier registration.

    # A provider function is registered under its return annotation
    # (here: Filesystem).
    def create_fs() -> Filesystem:
        return Filesystem()
    context.register(create_fs)
    # ...or register the class directly. This replaces the create_fs provider
    # for Filesystem, and register() prints a WARN line for the duplicate.
    context.register(Filesystem)
    # Settings needs no explicit register(): its decorator already did it.
    # TODO: how to cleanly re-register with different options?
    # Database was already registered (with cache=True) by its decorator; this
    # call replaces that registration with a non-caching one and prints a WARN.
    context.register(Database, cache=False)
    # Resolve dependencies: Database needs Settings, which needs Filesystem;
    # each is constructed through the context on demand.
    db_instance = context[Database]
    use_db(db_instance) # type checks under mypy
    # context[Unregistered] would raise ValueError: no provider is registered.
    # Cached registration: both lookups return the same instance.
    assert context[Filesystem] == context[Filesystem]
    # cache=False: each lookup constructs a new Database. Database defines no
    # __eq__, so the comparison is by identity.
    assert context[Database] != context[Database]
    # Test double. Its __init__ does not call Database.__init__, so it has no
    # db_name attribute.
    class FakeDatabase(Database):
        def __init__(self) -> None:
            ...
    fake_db = FakeDatabase()
    # Item assignment calls AppContext.set: it stores the instance and marks
    # the Database registration as cached, so later lookups return fake_db.
    context[Database] = fake_db # now cached
    assert context[Database] == fake_db
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment