Last active
February 2, 2024 14:59
-
-
Save tcwalther/ae058c64d5d9078a9f333913718bba95 to your computer and use it in GitHub Desktop.
DelayedInterrupt class - delaying the handling of process signals in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import signal | |
import logging | |
# class based on: http://stackoverflow.com/a/21919644/487556
class DelayedInterrupt(object):
    """Context manager that postpones the handling of one or more signals.

    While the ``with`` body runs, every signal in ``signals`` is caught by a
    temporary handler that only records it (the last occurrence per signal).
    On exit the previous handlers are reinstalled and each recorded signal is
    replayed to its previous handler -- for SIGINT that is usually the default
    handler, which then raises KeyboardInterrupt.

    Args:
        signals: a single signal number or a list/tuple of signal numbers.
    """

    def __init__(self, signals):
        if not isinstance(signals, (list, tuple)):
            signals = [signals]
        self.sigs = signals

    def __enter__(self):
        self.signal_received = {}
        self.old_handlers = {}
        for sig in self.sigs:
            self.signal_received[sig] = False
            self.old_handlers[sig] = signal.getsignal(sig)
            signal.signal(sig, self._handler)
        return self

    def _handler(self, s, frame):
        """Temporary handler: remember the delivered signal instead of acting on it."""
        # Key by the signal actually delivered (``s``). A closure over the loop
        # variable would late-bind and record every signal under the last one.
        self.signal_received[s] = (s, frame)
        # Tip: signal.Signals(s).name gives a readable name (Python 3.5+).
        logging.info('Signal %s received. Delaying KeyboardInterrupt.', s)

    def __exit__(self, type, value, traceback):
        # Reinstall every previous handler before replaying anything: a
        # replayed handler may raise (e.g. KeyboardInterrupt), and that must not
        # leave the remaining signals stuck on the delaying handler.
        for sig in self.sigs:
            signal.signal(sig, self.old_handlers[sig])
        for sig in self.sigs:
            old_handler = self.old_handlers[sig]
            # getsignal() may return SIG_DFL / SIG_IGN / None, none of which
            # can be called; such signals are not replayed.
            if self.signal_received[sig] and callable(old_handler):
                old_handler(*self.signal_received[sig])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import signal | |
from delayedinterrupt import DelayedInterrupt | |
from mock import Mock | |
def test_delayed_interrupt_with_one_signal():
    """SIGINT aborts a plain block immediately, but only after a DelayedInterrupt block."""
    # Baseline: without DelayedInterrupt the signal cuts the block short.
    first, second, on_interrupt = Mock(), Mock(), Mock()
    try:
        first()
        os.kill(os.getpid(), signal.SIGINT)
        second()
    except KeyboardInterrupt:
        on_interrupt()
    first.assert_called_with()
    second.assert_not_called()
    on_interrupt.assert_called_with()

    # With DelayedInterrupt: the whole block runs, then KeyboardInterrupt fires.
    first, second, on_interrupt = Mock(), Mock(), Mock()
    try:
        with DelayedInterrupt(signal.SIGINT):
            first()
            os.kill(os.getpid(), signal.SIGINT)
            second()
    except KeyboardInterrupt:
        on_interrupt()
    for called in (first, second, on_interrupt):
        called.assert_called_with()
def test_delayed_interrupt_with_multiple_signals():
    """SIGINT and SIGTERM sent inside the block are held back until it ends."""
    first, second, on_interrupt = Mock(), Mock(), Mock()
    own_pid = os.getpid()
    try:
        with DelayedInterrupt([signal.SIGTERM, signal.SIGINT]):
            first()
            for delivered in (signal.SIGINT, signal.SIGTERM):
                os.kill(own_pid, delivered)
            second()
    except KeyboardInterrupt:
        on_interrupt()
    for called in (first, second, on_interrupt):
        called.assert_called_with()
I wanted to use this in my project, so I refactored it into a form I consider more maintainable. I also added an option to exit early if an interrupt is repeated enough times, and a subclass that similarly ignores (rather than delays) signals unless the early exit is triggered.
import logging
import signal
class DelaySignals:
    """Context manager that holds back chosen signals until the ``with`` block ends.

    On ``__enter__`` the current handler of every signal in ``signals_to_delay``
    is saved and replaced by :meth:`signal_handler`, which only records the
    signal. On ``__exit__`` all saved handlers are reinstalled and every recorded
    signal is replayed, in arrival order, to its saved handler.
    """

    def __init__(self, signals_to_delay: list[signal.Signals] = signal.SIGINT, unless_repeated_n_times: int = False):
        """
        A class which intercepts chosen incoming signals after __enter__ and delays them till __exit__ is reached.

        Args:
            signals_to_delay (signal.Signals | list | tuple, optional): The signal or list/tuple of signals to delay. Defaults to signal.SIGINT.
            unless_repeated_n_times (int|bool, optional): If a signal is received N times, let it through by exiting
                early: the saved handlers are restored and the recorded signals are replayed immediately. Falsy disables
                this. Defaults to False.

        Example usage:
        ```py
        with DelaySignals( signal.SIGINT ):
            time.sleep(10) # If a SIGINT signal (KeyboardInterrupt) is received in this scope, it is delayed until the scope is exited
        ```

        Based on:
            http://stackoverflow.com/a/21919644/487556
            https://gist.github.com/tcwalther/ae058c64d5d9078a9f333913718bba95
        """
        if isinstance(signals_to_delay, (list, tuple)):
            self.signals_to_delay = list(signals_to_delay)
        else:
            self.signals_to_delay = [signals_to_delay]
        self.unless_repeated_n_times = unless_repeated_n_times
        # signal -> {'received': [(sig, frame), ...], 'handler': saved handler}
        self.inboxes = {}

    def __enter__(self):
        self.inboxes = {}
        for sig_type in self.signals_to_delay:
            # Create the inbox before installing the handler so a signal that
            # arrives right away has somewhere to go.
            self.inboxes[sig_type] = {'received': [], 'handler': signal.getsignal(sig_type)}
            signal.signal(sig_type, self.signal_handler)
        return self

    def signal_handler(self, sig, frame):
        """Record an intercepted signal; exit early once it has repeated enough times."""
        received = self.inboxes[sig]['received']
        received.append((sig, frame))
        # type(self), not __class__, so subclasses log under their own name.
        name = type(self).__name__
        if self.unless_repeated_n_times and len(received) >= self.unless_repeated_n_times:
            logging.warning(f"{name}: Signal {sig} repeated enough times ({self.unless_repeated_n_times}) to pass through, exiting early.")
            self.__exit__()
        else:
            logging.info(f"{name}: Signal {sig} handled.")

    def __exit__(self, *_):
        # Restore every saved handler before replaying anything: a replayed
        # handler may raise (SIGINT's default raises KeyboardInterrupt), which
        # would otherwise leave later signals stuck on signal_handler.
        for sig, inbox in self.inboxes.items():
            signal.signal(sig, inbox['handler'])
        # Take ownership of the recorded signals. After an early exit the
        # with-statement calls __exit__ again; it must not replay a second time.
        inboxes, self.inboxes = self.inboxes, {}
        for sig, inbox in inboxes.items():
            handler = inbox['handler']
            if callable(handler):
                for msg in inbox['received']:
                    handler(*msg)
            elif inbox['received']:
                # SIG_DFL / SIG_IGN / None cannot be called from Python.
                logging.warning(f"{type(self).__name__}: Signal {sig} had no prior handler, skipping.")


class IgnoreSignals(DelaySignals):
    """Like DelaySignals, but discards intercepted signals instead of replaying them.

    A signal is passed on only when ``unless_repeated_n_times`` is set and it was
    received at least that many times; it is then delivered once (its first
    recorded occurrence) to the saved handler.
    """

    def __exit__(self, *_):
        # Restore all saved handlers first (see DelaySignals.__exit__).
        for sig, inbox in self.inboxes.items():
            signal.signal(sig, inbox['handler'])
        # Make a repeated __exit__ (early exit followed by the normal exit) a no-op.
        inboxes, self.inboxes = self.inboxes, {}
        threshold = self.unless_repeated_n_times
        for sig, inbox in inboxes.items():
            received = inbox['received']
            # A falsy threshold means "always ignore"; never compare len() to False.
            if not (threshold and len(received) >= threshold):
                continue
            if callable(inbox['handler']):
                inbox['handler'](*received[0])
            else:
                logging.warning(f"{type(self).__name__}: Signal {sig} had no prior handler, skipping.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@irvinlim He's referring to line 18
should be
See this answer on SO