Created December 22, 2015 08:00
Save mgedmin/a91872054884dbaaa344 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/Lib/unittest/signals.py b/Lib/unittest/signals.py
index e6a5fc5..1ba8b43 100644
--- a/Lib/unittest/signals.py
+++ b/Lib/unittest/signals.py
@@ -6,6 +6,10 @@ from functools import wraps
 __unittest = True
 
 
+def _do_nothing_handler(unused_signum, unused_frame):
+    pass
+
+
 class _InterruptHandler(object):
     def __init__(self, default_handler):
         self.called = False
@@ -17,8 +21,7 @@ class _InterruptHandler(object):
             elif default_handler == signal.SIG_IGN:
                 # Not quite the same thing as SIG_IGN, but the closest we
                 # can make it: do nothing.
-                def default_handler(unused_signum, unused_frame):
-                    pass
+                default_handler = _do_nothing_handler
             else:
                 raise TypeError("expected SIGINT signal handler to be "
                                 "signal.SIG_IGN, signal.SIG_DFL, or a "
@@ -31,12 +34,19 @@ class _InterruptHandler(object):
             # if we aren't the installed handler, then delegate immediately
             # to the default handler
             self.default_handler(signum, frame)
+            return
 
         if self.called:
             self.default_handler(signum, frame)
+            return
         self.called = True
+        stopped = False
         for result in _results.keys():
             result.stop()
+            stopped = True
+        if not stopped:
+            # if there aren't any registered results, delegate immediately
+            self.default_handler(signum, frame)
 
 _results = weakref.WeakKeyDictionary()
 def registerResult(result):
diff --git a/Lib/unittest/test/test_signals.py b/Lib/unittest/test/test_signals.py
new file mode 100644
index 0000000..5116796
--- /dev/null
+++ b/Lib/unittest/test/test_signals.py
@@ -0,0 +1,61 @@
+import inspect
+import signal
+import unittest
+from unittest import mock, signals
+
+
+class Test_InterruptHandler(unittest.TestCase):
+
+    def test_init_recognizes_default_handler(self):
+        handler = signals._InterruptHandler(signal.SIG_DFL)
+        self.assertEqual(handler.default_handler, signal.default_int_handler)
+
+    def test_init_recognizes_sigign(self):
+        handler = signals._InterruptHandler(signal.SIG_IGN)
+        self.assertEqual(handler.default_handler, signals._do_nothing_handler)
+
+    def test_init_refuses_unexpected_values(self):
+        with self.assertRaises(TypeError):
+            signals._InterruptHandler(42)
+
+    @mock.patch('unittest.signals._results', {})
+    def test_call_with_no_registered_results(self):
+        default_handler = mock.Mock()
+        handler = signals._InterruptHandler(default_handler)
+        with mock.patch('signal.getsignal', lambda sig: handler):
+            handler(signal.SIGINT, inspect.currentframe())
+        self.assertTrue(handler.called)
+        self.assertEqual(default_handler.call_count, 1)
+
+    def test_call_with_registered_results(self):
+        default_handler = mock.Mock()
+        result = mock.Mock()
+        handler = signals._InterruptHandler(default_handler)
+        with mock.patch('signal.getsignal', lambda sig: handler), \
+             mock.patch('unittest.signals._results', {result: 1}):
+            handler(signal.SIGINT, inspect.currentframe())
+        self.assertTrue(handler.called)
+        self.assertEqual(default_handler.call_count, 0)
+        self.assertEqual(result.stop.call_count, 1)
+
+    def test_call_twice(self):
+        default_handler = mock.Mock()
+        result = mock.Mock()
+        handler = signals._InterruptHandler(default_handler)
+        with mock.patch('signal.getsignal', lambda sig: handler), \
+             mock.patch('unittest.signals._results', {result: 1}):
+            handler(signal.SIGINT, inspect.currentframe())
+            handler(signal.SIGINT, inspect.currentframe())
+        self.assertTrue(handler.called)
+        self.assertEqual(default_handler.call_count, 1)
+        self.assertEqual(result.stop.call_count, 1)
+
+    def test_call_when_not_installed(self):
+        default_handler = mock.Mock()
+        result = mock.Mock()
+        handler = signals._InterruptHandler(default_handler)
+        with mock.patch('unittest.signals._results', {result: 1}):
+            handler(signal.SIGINT, inspect.currentframe())
+        self.assertFalse(handler.called)
+        self.assertEqual(default_handler.call_count, 1)
+        self.assertEqual(result.stop.call_count, 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.