Last active
December 7, 2023 17:06
-
-
Save snewell92/ea76cf67e180af06b1390817e70bb02f to your computer and use it in GitHub Desktop.
Wireup DI Override
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import (
    Any,
    Dict,
    Generic,
    List,
    Literal,
    Optional,
    Tuple,
    TypeVar,
)

from wireup import DependencyContainer, container
from wireup.ioc.types import ContainerProxyQualifierValue
# Type of the key being overridden (typically the service class the container resolves).
TTarget = TypeVar("TTarget")
# Type of the replacement object handed out while an override is active.
TOverride = TypeVar("TOverride")


class ContainerOverride(Generic[TTarget, TOverride]):
    """
    Temporarily replace one object cached by a wireup ``DependencyContainer``.

    Entering the context writes ``override`` into the container's map of
    initialized objects under the key ``(target, qualifier)``. Exiting puts back
    whatever that key held when the context was entered, or removes the key if
    it held nothing. Exceptions raised inside the ``with`` body are NOT
    suppressed, so failing assertions still fail the test.

    A DI Override born from discussion here: https://github.com/maldoinc/wireup/issues/7#issuecomment-1823476980
    Latest version also published in this gist: https://gist.github.com/snewell92/ea76cf67e180af06b1390817e70bb02f

    Example:
    ```python
    with ContainerOverride(FeatureFlagService, Mock(FeatureFlagService)) as mock:
        mock.is_calendar_integration_enabled.return_value = True
        do_some_work()
    ```

    NOTE(review): this reads the name-mangled private attribute
    ``DependencyContainer.__initialized_objects``. It depends on wireup
    internals and may break when wireup changes.
    """

    def __init__(
        self,
        target: TTarget,
        override: TOverride,
        dependency_container: Optional[DependencyContainer] = None,
        qualifier: ContainerProxyQualifierValue = None,
    ):
        """
        Args:
            target: First half of the cache key; the type being overridden.
            override: Object stored under the key while the context is active.
            dependency_container: Container to patch. When ``None``, wireup's
                module-level ``container`` is used.
            qualifier: Second half of the cache key; ``None`` for unqualified
                registrations.
        """
        self.container = (
            dependency_container if dependency_container is not None else container
        )
        self.target = target
        self.override = override
        self.qualifier = qualifier
        self.__di_objects: Dict[
            Tuple[TTarget, ContainerProxyQualifierValue], Any
        ] = self.container._DependencyContainer__initialized_objects  # type: ignore [attr-defined]
        self._snapshot()

    def _snapshot(self) -> None:
        """Record whether the key is currently cached and, if so, its value."""
        key = (self.target, self.qualifier)
        # Presence is tracked separately from the value so that a cached object
        # which is falsy (or None) is restored on exit instead of deleted.
        self.__had_existing = key in self.__di_objects
        self.existing_instance = self.__di_objects.get(key)

    def set_value(self, value: Any) -> None:
        """Store ``value`` in the container under ``(target, qualifier)``."""
        self.__di_objects[(self.target, self.qualifier)] = value

    def __enter__(self) -> TOverride:
        """Install the override and return it."""
        # Snapshot again at entry: the container may have initialized the target
        # after this object was constructed, and that instance must be restored.
        self._snapshot()
        self.set_value(self.override)
        return self.override

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        """Restore the previous entry (or remove the key); never swallow exceptions."""
        if self.__had_existing:
            self.set_value(self.existing_instance)
        else:
            # pop() rather than del: the with-body may already have removed the key.
            self.__di_objects.pop((self.target, self.qualifier), None)
        # A falsy return tells Python to propagate any exception from the body.
        return False
class ContainerOverrides(Generic[TTarget, TOverride]):
    """
    Temporarily replace several objects cached by a wireup ``DependencyContainer``.

    A list version of ``ContainerOverride``. Each ``(target, override, qualifier)``
    triple is written into the container's map of initialized objects on entry.
    On exit every affected key is restored to the object it held when the
    context was entered, or removed if it held nothing. Exceptions raised inside
    the ``with`` body are NOT suppressed.

    Example:
    ```python
    overrides = [
        (TargetType, Mock(), None),
        (OtherTarget, self.my_mock, None),
        (FinalType, MemoryImpl(), None)
    ]
    with ContainerOverrides(overrides):
        do_some_work()
    ```

    NOTE(review): this reads the name-mangled private attribute
    ``DependencyContainer.__initialized_objects``. It depends on wireup
    internals and may break when wireup changes.
    """

    def __init__(
        self,
        overrides: List[
            Tuple[TTarget, TOverride, Optional[ContainerProxyQualifierValue]]
        ],
        dependency_container: Optional[DependencyContainer] = None,
    ):
        """
        Args:
            overrides: ``(target, override, qualifier)`` triples; the qualifier
                is ``None`` for unqualified registrations.
            dependency_container: Container to patch. When ``None``, wireup's
                module-level ``container`` is used.
        """
        self.container = (
            dependency_container if dependency_container is not None else container
        )
        self.overrides = overrides
        self.__di_objects: Dict[
            Tuple[TTarget, ContainerProxyQualifierValue], Any
        ] = self.container._DependencyContainer__initialized_objects  # type: ignore [attr-defined]
        self._snapshot()

    def _snapshot(self) -> None:
        """Record, for every override key, whether it is cached and its current value."""
        # Public, kept for backward compatibility: (target, cached value or None, qualifier).
        self.existing_instances = [
            (target, self.__di_objects.get((target, qualifier)), qualifier)
            for target, _, qualifier in self.overrides
        ]
        # Only keys that are actually present. Keys absent here are removed on
        # exit rather than left mapped to None in the container.
        self.__previous: Dict[Tuple[Any, Any], Any] = {
            (target, qualifier): self.__di_objects[(target, qualifier)]
            for target, _, qualifier in self.overrides
            if (target, qualifier) in self.__di_objects
        }

    def set_value(self, target: Any, value: Any, qualifier: Any) -> None:
        """Store ``value`` in the container under ``(target, qualifier)``."""
        self.__di_objects[(target, qualifier)] = value

    def __enter__(self) -> List[TOverride]:
        """Install every override and return the override objects in input order."""
        # Snapshot again at entry: the container may have initialized some
        # targets after this object was constructed.
        self._snapshot()
        for target, value, qualifier in self.overrides:
            self.set_value(target, value, qualifier)
        return [value for _, value, _ in self.overrides]

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        """Restore or remove every overridden key; never swallow exceptions."""
        for target, _, qualifier in self.overrides:
            key = (target, qualifier)
            if key in self.__previous:
                self.set_value(target, self.__previous[key], qualifier)
            else:
                # pop() rather than del: duplicate keys in ``overrides`` or the
                # with-body may already have removed it.
                self.__di_objects.pop(key, None)
        # A falsy return tells Python to propagate any exception from the body.
        return False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from unittest import TestCase | |
from app import app | |
class TestSample(TestCase):
    """
    Sample tests showing ContainerOverride / ContainerOverrides in use.

    NOTE(review): this sample does not import ``Mock``, ``ContainerOverride``,
    ``ContainerOverrides`` or the service classes it references; add those
    imports when copying it into a real test module.
    """

    def test_sample_inline(self) -> None:
        # The override object is created inline and obtained through ``as``.
        with ContainerOverride(FeatureFlagService, Mock(FeatureFlagService)) as mock:
            mock.is_calendar_integration_enabled.return_value = True
            with app.test_client() as client:
                response = client.get("/some/fake/endpoint")
        # Assert after leaving the override so a failure is reported even if
        # the override's __exit__ suppresses exceptions.
        self.assertEqual(200, response.status_code)

    def test_sample_reusale(self) -> None:
        # The mock is created up front so it can be configured and inspected
        # outside the with-block.
        # NOTE(review): the mock is spec'd as FakeService but overrides
        # FeatureFlagService — confirm which spec is intended.
        mock = Mock(FakeService)
        with ContainerOverride(FeatureFlagService, mock):
            mock.fake_method.return_value = False
            with app.test_client() as client:
                response = client.get("/some/fake/endpoint")
        self.assertEqual(400, response.status_code)

    def test_multiple(self) -> None:
        mock = Mock(SomeService)
        mock2 = Mock(SecondService)
        overrides = [
            (SomeService, mock, None),
            (SecondService, mock2, None),
        ]
        with ContainerOverrides(overrides):
            with app.test_client() as client:
                response = client.get("/some/workflow")
        # ``assert_called_once`` is the real Mock assertion. Mock raises
        # AttributeError for unknown attributes starting with "assert", such
        # as the misspelled ``assertCalledOnce``.
        mock2.add.assert_called_once()
        self.assertEqual(200, response.status_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment