Last active
August 29, 2015 14:04
-
-
Save n0phx/5b7a7d089b79395bbac9 to your computer and use it in GitHub Desktop.
Helper decorator for Django unittests, used to make sure that the expected signals were fired the expected number of times.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Helper decorator for Django unittests, used to make sure that the expected
signals were fired the expected number of times.

Usage:

    class MyTestCase(TestCase):

        @assert_signal_fired({
            'my_signal_func': (my_signal_func, 1),
            'other_signal_func': (other_signal_func, 0)
        })
        def test_something(self):
            ...test code here

It will connect a temporary handler function to the specified signals, count
the number of times those signals were invoked and at the end of the test
assert whether the signals were called the expected number of times, and will
properly clean up after itself by disconnecting the temporary signal handlers.
""" | |
from functools import partial | |
def assert_signal_fired(expectations):
    """Build a test-method decorator that checks how often signals fire.

    Args:
        expectations: dict mapping a display name to a
            ``(signal, expected_call_count)`` tuple. ``signal`` must provide
            ``connect(handler)`` and ``disconnect(handler)``; the display
            name is only used in the failure message.

    Returns:
        A decorator for test methods. The decorated method connects one
        temporary handler per entry, runs the test, disconnects every
        handler (even when the test raises) and then calls
        ``self.assertEqual`` once per entry to compare the observed call
        count with the expected one. The test's return value is passed
        through.
    """
    # Function-scope import: only ``partial`` is imported at module level.
    from functools import wraps

    def _assert_signal_fired(func):
        @wraps(func)  # keep the test's name/docstring for test reports
        def __assert_signal_fired(self, *args, **kwargs):
            # Fresh counters on every invocation; sharing one dict across
            # runs would make counts accumulate when the same decorated
            # test executes more than once.
            counters = {}

            def signal_fired_handler(signal_name, *h_args, **h_kwargs):
                counters[signal_name] = counters.get(signal_name, 0) + 1

            # Holding the handlers here lets us disconnect them later and
            # keeps them referenced for the duration of the test (relevant
            # if the signal only stores weak references to receivers).
            connected = []
            try:
                for signal_name, (signal_func, _) in expectations.items():
                    handler = partial(signal_fired_handler, signal_name)
                    signal_func.connect(handler)
                    connected.append((signal_func, handler))
                result = func(self, *args, **kwargs)
            finally:
                # Always clean up, including when connecting or the test
                # itself raised part-way through.
                for signal_func, handler in connected:
                    signal_func.disconnect(handler)

            # Assert only after every handler is gone, so a failed
            # expectation cannot leave other handlers connected.
            for signal_name, (_, expected_call_count) in expectations.items():
                call_count = counters.get(signal_name, 0)
                msg = ('{0} was expected to be called {1} time(s), '
                       'but was called {2} time(s).')
                msg = msg.format(signal_name, expected_call_count, call_count)
                self.assertEqual(call_count, expected_call_count, msg=msg)
            return result
        return __assert_signal_fired
    return _assert_signal_fired
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment