Skip to content

Instantly share code, notes, and snippets.

@snewell92
Last active December 7, 2023 17:06
Show Gist options
  • Save snewell92/ea76cf67e180af06b1390817e70bb02f to your computer and use it in GitHub Desktop.
Save snewell92/ea76cf67e180af06b1390817e70bb02f to your computer and use it in GitHub Desktop.
Wireup DI Override
from typing import (
    Any,
    Dict,
    Generic,
    List,
    Literal,
    Optional,
    Tuple,
    TypeVar,
)

from wireup import DependencyContainer, container
from wireup.ioc.types import ContainerProxyQualifierValue
# Type parameters for the override helpers: the type being replaced in the
# container, and the type of the object handed out in its place.
TTarget, TOverride = TypeVar("TTarget"), TypeVar("TOverride")
class ContainerOverride(Generic[TTarget, TOverride]):
    """
    Temporarily replace one dependency in a wireup container.

    A DI Override born from discussion here: https://github.com/maldoinc/wireup/issues/7#issuecomment-1823476980
    Latest version also published in this gist: https://gist.github.com/snewell92/ea76cf67e180af06b1390817e70bb02f

    On ``__enter__`` the ``override`` object is stored in the container's
    initialized-object map under ``(target, qualifier)`` and returned.
    On ``__exit__`` the previous entry for that key is put back, or the key
    is removed if it had no entry. Exceptions raised inside the ``with``
    block always propagate.

    Example:
        with ContainerOverride(FeatureFlagService, Mock(FeatureFlagService)) as mock:
            mock.is_calendar_integration_enabled.return_value = True
            ...
    """

    def __init__(
        self,
        target: TTarget,
        override: TOverride,
        dependency_container: Optional[DependencyContainer] = None,
        qualifier: ContainerProxyQualifierValue = None,
    ):
        """
        Args:
            target: The registered type whose instance should be replaced.
            override: The object to store in place of the target's instance.
            dependency_container: Container to patch; defaults to the
                ``container`` object imported from wireup.
            qualifier: Qualifier the target is registered under, if any.
        """
        self.container = (
            dependency_container if dependency_container is not None else container
        )
        self.target = target
        self.override = override
        self.qualifier = qualifier
        # NOTE: reads DependencyContainer's private (name-mangled)
        # ``__initialized_objects`` dict, presumably wireup's cache of
        # already-constructed instances; this can break if wireup renames it.
        self.__di_objects: Dict[
            Tuple[TTarget, ContainerProxyQualifierValue], Any
        ] = self.container._DependencyContainer__initialized_objects  # type: ignore [attr-defined]
        self._snapshot_existing()

    def _snapshot_existing(self) -> None:
        """Record whether, and with what, the key is currently populated."""
        key = (self.target, self.qualifier)
        # Presence is tracked separately from the value so that a falsy
        # (or None) cached object is restored rather than deleted on exit.
        self._had_existing = key in self.__di_objects
        self.existing_instance = self.__di_objects.get(key)

    def set_value(self, value: Any) -> None:
        """Store ``value`` in the container under ``(target, qualifier)``."""
        self.__di_objects[(self.target, self.qualifier)] = value

    def __enter__(self) -> TOverride:
        # Re-snapshot here: the container may have created the real instance
        # after __init__ ran, and that instance must survive the override.
        self._snapshot_existing()
        self.set_value(self.override)
        return self.override

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        if self._had_existing:
            self.set_value(self.existing_instance)
        else:
            # pop() tolerates the key having been removed inside the block.
            self.__di_objects.pop((self.target, self.qualifier), None)
        # Returning a false value lets exceptions from the with-block
        # (including failed test assertions) propagate instead of being
        # silently suppressed.
        return False
class ContainerOverrides(Generic[TTarget, TOverride]):
    """
    A list version of ContainerOverride: temporarily replace several
    dependencies in a wireup container at once.

    Each override is a ``(target, override, qualifier)`` tuple. On
    ``__enter__`` every override object is stored in the container's
    initialized-object map under ``(target, qualifier)``. The list of
    override objects is returned. On ``__exit__`` each key is restored to
    the entry it had before, or removed if it had none. Exceptions raised
    inside the ``with`` block always propagate.

    Example:
    ```python
    overrides = [
        (TargetType, Mock(), None),
        (OtherTarget, self.my_mock, None),
        (FinalType, MemoryImpl(), None)
    ]
    with ContainerOverrides(overrides):
        do_some_work()
    ```
    """

    def __init__(
        self,
        overrides: List[
            Tuple[TTarget, TOverride, Optional[ContainerProxyQualifierValue]]
        ],
        dependency_container: Optional[DependencyContainer] = None,
    ):
        """
        Args:
            overrides: ``(target, override, qualifier)`` tuples to apply.
            dependency_container: Container to patch; defaults to the
                ``container`` object imported from wireup.
        """
        self.container = (
            dependency_container if dependency_container is not None else container
        )
        self.overrides = overrides
        # NOTE: reads DependencyContainer's private (name-mangled)
        # ``__initialized_objects`` dict, presumably wireup's cache of
        # already-constructed instances; this can break if wireup renames it.
        self.__di_objects: Dict[
            Tuple[TTarget, ContainerProxyQualifierValue], Any
        ] = self.container._DependencyContainer__initialized_objects  # type: ignore [attr-defined]
        self._snapshot_existing()

    def _snapshot_existing(self) -> None:
        """Record what every override key currently holds in the container."""
        # (target, previous value or None, qualifier) per override.
        self.existing_instances = [
            (
                override[0],
                self.__di_objects.get((override[0], override[2])),
                override[2],
            )
            for override in self.overrides
        ]
        # Presence is tracked separately from the value so that keys that
        # were not cached are removed on exit instead of being set to None.
        self._had_existing: List[bool] = [
            (override[0], override[2]) in self.__di_objects
            for override in self.overrides
        ]

    def set_value(self, target: Any, value: Any, qualifier: Any) -> None:
        """Store ``value`` in the container under ``(target, qualifier)``."""
        self.__di_objects[(target, qualifier)] = value

    def __enter__(self) -> List[TOverride]:
        # Re-snapshot here: the container may have created real instances
        # after __init__ ran, and those must survive the override.
        self._snapshot_existing()
        for override in self.overrides:
            self.set_value(override[0], override[1], override[2])
        return [override[1] for override in self.overrides]

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        """
        Restore every overridden key to its pre-override entry, removing keys
        that had no entry before. Never suppresses exceptions.
        """
        for (target, previous, qualifier), had_existing in zip(
            self.existing_instances, self._had_existing
        ):
            if had_existing:
                self.set_value(target, previous, qualifier)
            else:
                # pop() also tolerates the same key appearing twice in
                # overrides, or being removed inside the block.
                self.__di_objects.pop((target, qualifier), None)
        return False
from unittest import TestCase
from unittest.mock import Mock

from app import app

# NOTE(review): FeatureFlagService, FakeService, SomeService and SecondService
# are placeholders for the application's own services; import them from
# wherever they are defined before running these tests.


class TestSample(TestCase):
    """Examples of overriding wireup dependencies around test-client requests."""

    def test_sample_inline(self) -> None:
        # __enter__ returns the override object, so it can be configured
        # directly through ``as``.
        with ContainerOverride(FeatureFlagService, Mock(FeatureFlagService)) as mock:
            mock.is_calendar_integration_enabled.return_value = True
            with app.test_client() as client:
                response = client.get("/some/fake/endpoint")
                self.assertEqual(200, response.status_code)

    def test_sample_reusale(self) -> None:
        # Creating the mock up front keeps a reference to it outside the
        # with-block.
        # NOTE(review): the spec is FakeService while the override target is
        # FeatureFlagService — confirm which is intended.
        mock = Mock(FakeService)
        with ContainerOverride(FeatureFlagService, mock):
            mock.fake_method.return_value = False
            with app.test_client() as client:
                response = client.get("/some/fake/endpoint")
                self.assertEqual(400, response.status_code)

    def test_multiple(self) -> None:
        mock = Mock(SomeService)
        mock2 = Mock(SecondService)
        overrides = [
            (SomeService, mock, None),
            (SecondService, mock2, None),
        ]
        with ContainerOverrides(overrides):
            with app.test_client() as client:
                response = client.get("/some/workflow")
                # Mock's assertion is ``assert_called_once``; an unknown
                # ``assert*`` attribute raises AttributeError instead of
                # checking anything.
                mock2.add.assert_called_once()
                self.assertEqual(200, response.status_code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment