Created
December 16, 2017 21:58
-
-
Save guyjacks/d1704cc3f57805b6973a5d519d6f7d7d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Call(object):
    """A recorded method call: the method name plus its positional and keyword arguments.

    Keyword arguments named ``name`` or ``self`` cannot be passed through the
    constructor because they collide with its own parameters; assign ``kwargs``
    after construction in that case.
    """

    def __init__(self, name, *args, **kwargs):
        self.name = name
        self.args = args
        self.kwargs = kwargs

    def __eq__(self, other):
        """Two calls are equal when their name, args and kwargs all match.

        Objects without ``name``/``args``/``kwargs`` attributes (e.g. ``None``)
        yield ``NotImplemented``, so comparing against them evaluates to False
        instead of raising AttributeError.
        """
        try:
            other_fields = (other.name, other.args, other.kwargs)
        except AttributeError:
            return NotImplemented
        return (self.name, self.args, self.kwargs) == other_fields

    def __ne__(self, other):
        # Explicit inverse of __eq__: Python 2 does not derive __ne__ from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __repr__(self):
        return "{} - {} - {}".format(self.name, self.args, self.kwargs)
class GenericFactory(object):
    """Build instances of a fixed class, filling in default constructor arguments.

    ``create(caller, call, *args, **kwargs)`` instantiates ``TypeClass`` with
    ``caller`` and ``call`` as its first two positional arguments, followed by the
    extra arguments for that call; when a call supplies none, the defaults
    captured at construction time are used instead.
    """

    def __init__(self, TypeClass, *args, **kwargs):
        """
        :param TypeClass: class (or any callable) used to build each instance.
        :param args: default extra positional arguments for ``create``.
        :param kwargs: default extra keyword arguments for ``create``.
        """
        self.TypeClass = TypeClass
        # Keep the packed tuple/dict so create() can re-expand them.
        self.default_args = args
        self.default_kwargs = kwargs

    def create(self, caller, call, *args, **kwargs):
        """Return ``TypeClass(caller, call, *args, **kwargs)``.

        Positional and keyword defaults are substituted independently: an empty
        ``args`` falls back to ``default_args`` and an empty ``kwargs`` falls back
        to ``default_kwargs``.
        """
        args = args or self.default_args
        kwargs = kwargs or self.default_kwargs
        return self.TypeClass(caller, call, *args, **kwargs)
class MethodSpy(object):
    """Record calls made to one method of a mock and hand back a configurable result.

    Every call is stored as a ``Call`` in ``self.calls`` so tests can assert on it
    with ``has_call``, ``has_calls``, ``called_with`` and ``chain_called``.
    """

    def __init__(self, caller, name, members=None, return_value_type=None,
                 default_args=None, default_kwargs=None):
        """
        :param caller: the mock instance this method was accessed on.
        :param name: the name of the method being spied.
        :param members: raw result set of the mocked queryset; a fresh list when omitted.
        :param return_value_type: callable invoked as
            ``return_value_type(caller, call, *default_args, **default_kwargs)`` to
            build the value returned by ``__call__``. When omitted, a new mock of
            ``type(caller)`` whose parent is ``caller`` is built.
        :param default_args: extra positional arguments forwarded to ``return_value_type``.
        :param default_kwargs: extra keyword arguments forwarded to ``return_value_type``.
        """
        self.caller = caller
        self.name = name
        # A fresh list per spy instead of a shared mutable default.
        self.members = [] if members is None else members
        # Value returned by the most recent call - see __call__().
        self.returned = None
        # Value returned by calls matching no return_value_when() rule; built lazily.
        self.return_value = None
        # (Call, value) pairs registered via return_value_when(); later entries win.
        self._conditional_returns = []
        if return_value_type is None:
            return_value_type = self._make_chained_mock
        self.factory = GenericFactory(
            return_value_type, *(default_args or ()), **(default_kwargs or {}))
        # Every call made to this method, in order.
        # REFACTOR: should be a list of tuples, each holding a call object and its
        # return value.
        self.calls = []

    def _make_chained_mock(self, caller, call, *args, **kwargs):
        """Default return-value builder: a mock of the caller's own type.

        The new mock gets ``caller`` as its parent and ``call`` as the call that
        produced it, which is what ``chain_called`` walks. Assumes ``type(caller)``
        accepts ``(members, parent, call, ...)`` positionally, as ChainableMock does.
        """
        return type(caller)(self.members, caller, call, *args, **kwargs)

    def _make_call(self, args, kwargs):
        """Build a Call for this method from already-packed args/kwargs.

        kwargs are assigned after construction so keyword arguments named ``name``
        or ``self`` (e.g. ``filter(name='x')``) do not collide with Call's own
        constructor parameters.
        """
        call = Call(self.name, *args)
        call.kwargs = dict(kwargs)
        return call

    def __getattr__(self, attr):
        """Allow chaining assertions through the value returned by the last call.

        ``base_query_set.method.method.method.called_with()``: each method in the
        chain returns a mock, but the assertion API lives on the MethodSpy, so an
        unknown attribute on this spy resolves to the spy of that name on
        ``self.returned``.

        :raises AttributeError: if nothing has been returned yet, the returned value
            has no ``spies``, or it has no spy named ``attr``.
        """
        # Read via __dict__ so a missing 'returned' (e.g. an instance created by
        # copy/unpickling without __init__) cannot recurse back into __getattr__.
        returned = self.__dict__.get('returned')
        spies = getattr(returned, 'spies', None)
        if spies is None or attr not in spies:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, attr))
        return spies[attr]

    def __call__(self, *args, **kwargs):
        """Record the call and return the configured value.

        A value registered with ``return_value_when`` for exactly these arguments
        takes precedence; otherwise ``self.return_value`` is returned, built once
        through ``self.factory`` on the first such call.
        """
        call = self._make_call(args, kwargs)
        self.calls.append(call)
        for expected_call, value in reversed(self._conditional_returns):
            if expected_call == call:
                self.returned = value
                return value
        # `is None` rather than truthiness so falsy configured values are honoured.
        # NOTE: the built value is cached, so its `call` reflects the first call
        # that built it.
        if self.return_value is None:
            self.return_value = self.factory.create(self.caller, call)
        self.returned = self.return_value
        return self.returned

    def return_value_when(self, value, *args, **kwargs):
        """Return ``value`` whenever this spy is called with exactly these arguments.

        i.e. ``qs_mock.filter.return_value_when(some_qs, 1000, a=10000)`` makes
        ``qs_mock.filter(1000, a=10000)`` return ``some_qs``. Registering the same
        arguments again replaces the earlier value.

        :param value: the object to return for a matching call.
        :param args: positional arguments the call must match.
        :param kwargs: keyword arguments the call must match.
        """
        self._conditional_returns.append((self._make_call(args, kwargs), value))

    def has_call(self, expected_call):
        """Return True if any recorded call equals ``expected_call``."""
        return any(actual_call == expected_call for actual_call in self.calls)

    def has_calls(self, expected_calls):
        """Return True if every call in ``expected_calls`` was recorded (in any order)."""
        return all(self.has_call(expected_call) for expected_call in expected_calls)

    def called_with(self, *args, **kwargs):
        """Return True if this method was called with exactly these arguments."""
        return self.has_call(self._make_call(args, kwargs))

    def chain_called(self, expected_calls):
        """Check that a chain of calls ending with this method was made.

        The last entry of ``expected_calls`` is matched against this spy's recorded
        calls; the earlier entries, from last to first, are matched against the
        ``call`` attribute of each ancestor, walking ``caller.parent`` upwards.

        :param expected_calls: non-empty list of Call objects, first call first.
        :return: True if every call in the chain matches; False otherwise,
            including when the expected chain is longer than the actual one.
        """
        if not self.has_call(expected_calls[-1]):
            return False
        caller = self.caller
        for expected_call in reversed(expected_calls[:-1]):
            # Stop at the root (or past it) instead of dereferencing None.
            actual_call = getattr(caller, 'call', None)
            if actual_call is None or not expected_call == actual_call:
                return False
            caller = caller.parent
        return True
class ChainableMock(object):
    """Iterable mock whose attributes are MethodSpy objects created on first access.

    ``parent`` and ``call`` are what ``MethodSpy.chain_called`` walks to verify a
    chain of calls such as ``qs.filter(...).exclude(...)``.
    """

    def __init__(self, members, parent=None, call=None):
        """
        :param members: items yielded when iterating the mock (the raw result set).
        :param parent: presumably the mock whose method call produced this one;
            None for the root.
        :param call: presumably the Call that produced this mock; None for the root.
        """
        self.members = members
        self.parent = parent
        self.call = call
        # Spies by method name; together with parent/call this tracks the current chain.
        self.spies = {}

    def __iter__(self):
        return iter(self.members)

    def __getattr__(self, name):
        """Return the spy for ``name``, creating and caching it on first access.

        :raises AttributeError: for dunder names, so protocol lookups made by tools
            such as ``copy`` and ``pickle`` are not turned into spies, and when
            ``spies`` has not been initialised yet.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # Read via __dict__ so a missing 'spies' (e.g. an instance created by
        # copy/unpickling without __init__) cannot recurse back into __getattr__.
        spies = self.__dict__.get('spies')
        if spies is None:
            raise AttributeError(name)
        spy = spies.get(name)
        if spy is None:
            spy = spies[name] = MethodSpy(self, name, self.members)
        return spy
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment