Skip to content

Instantly share code, notes, and snippets.

@tcwalther
Last active February 2, 2024 14:59
Show Gist options
  • Save tcwalther/ae058c64d5d9078a9f333913718bba95 to your computer and use it in GitHub Desktop.
Save tcwalther/ae058c64d5d9078a9f333913718bba95 to your computer and use it in GitHub Desktop.
DelayedInterrupt class - delaying the handling of process signals in Python
import signal
import logging
# class based on: http://stackoverflow.com/a/21919644/487556
class DelayedInterrupt(object):
    """Postpone the handling of one or more signals until the ``with`` block exits.

    While the block runs, each signal in ``signals`` is intercepted and its most
    recent ``(signum, frame)`` pair is recorded. On exit the original handlers are
    reinstalled and, for every signal that arrived, its original handler is called
    with the recorded pair. For SIGINT under Python's default handler, this raises
    KeyboardInterrupt at the end of the block.

    Based on: http://stackoverflow.com/a/21919644/487556
    """

    def __init__(self, signals):
        """``signals`` is a single signal number or a list/tuple of them."""
        if not isinstance(signals, (list, tuple)):
            signals = [signals]
        self.sigs = signals

    def __enter__(self):
        self.signal_received = {}
        self.old_handlers = {}
        for sig in self.sigs:
            self.signal_received[sig] = False
            self.old_handlers[sig] = signal.getsignal(sig)
            signal.signal(sig, self._record_signal)
        return self

    def _record_signal(self, signum, frame):
        """Temporary handler: remember the signal instead of acting on it.

        The entry is keyed by the delivered ``signum`` rather than by a loop
        variable captured in a closure. A closure over the loop variable would
        bind late and record every signal under the last one in ``self.sigs``.
        """
        self.signal_received[signum] = (signum, frame)
        # Note: in Python 3.5+, signal.Signals(signum).name gives a readable name.
        logging.info('Signal %s received. Delaying KeyboardInterrupt.', signum)

    def __exit__(self, type, value, traceback):
        # Reinstall every original handler before calling any of them. A replayed
        # handler may raise (e.g. KeyboardInterrupt), which would otherwise leave
        # the remaining signals routed to _record_signal.
        for sig in self.sigs:
            signal.signal(sig, self.old_handlers[sig])
        for sig in self.sigs:
            received = self.signal_received[sig]
            handler = self.old_handlers[sig]
            # SIG_DFL, SIG_IGN and None are not callable and cannot be replayed.
            # Calling SIG_IGN (which is truthy) would raise TypeError.
            if received and callable(handler):
                handler(*received)
import os
import signal
from delayedinterrupt import DelayedInterrupt
from mock import Mock
def test_delayed_interrupt_with_one_signal():
    """SIGINT aborts the try-body at once, but is postponed inside DelayedInterrupt."""
    pid = os.getpid()

    # Baseline: without DelayedInterrupt the KeyboardInterrupt fires right away,
    # so the call after os.kill never happens.
    before, after, on_interrupt = Mock(), Mock(), Mock()
    try:
        before()
        os.kill(pid, signal.SIGINT)
        after()
    except KeyboardInterrupt:
        on_interrupt()
    before.assert_called_with()
    after.assert_not_called()
    on_interrupt.assert_called_with()

    # With DelayedInterrupt the body runs to completion and the interrupt
    # surfaces only once the with-block is left.
    before, after, on_interrupt = Mock(), Mock(), Mock()
    try:
        with DelayedInterrupt(signal.SIGINT):
            before()
            os.kill(pid, signal.SIGINT)
            after()
    except KeyboardInterrupt:
        on_interrupt()
    before.assert_called_with()
    after.assert_called_with()
    on_interrupt.assert_called_with()
def test_delayed_interrupt_with_multiple_signals():
    """Every listed signal is held back; the pending SIGINT surfaces on exit."""
    pid = os.getpid()
    before, after, on_interrupt = Mock(), Mock(), Mock()
    try:
        with DelayedInterrupt([signal.SIGTERM, signal.SIGINT]):
            before()
            for sig in (signal.SIGINT, signal.SIGTERM):
                os.kill(pid, sig)
            after()
    except KeyboardInterrupt:
        on_interrupt()
    for mock_call in (before, after, on_interrupt):
        mock_call.assert_called_with()
@irvinlim
Copy link

To fix this, add `sig=sig` to the parameter list.

@benrg Can you clarify which method's parameter list you are referring to?

@kliberty
Copy link

@irvinlim He's referring to line 18 of the gist, i.e. the nested `handler` function defined inside `__enter__`:

def handler(s, frame):
    self.signal_received[sig] = (s, frame)
                       #  ^ this sig is captured by reference

should be

def handler(s, frame, sig=sig):
    self.signal_received[sig] = (s, frame)

See this answer on SO

@jaggujajamensan
Copy link

jaggujajamensan commented Feb 2, 2024

Wanted to use this for my project, so I refactored it in a way I consider more maintainable. I also added an option to exit early if an interrupt is repeated enough times, as well as a subclass that similarly ignores (rather than delays) signals unless early exit is triggered.

import logging
import signal

class DelaySignals:
    def __init__(self, signals_to_delay: list[signal.Signals] = signal.SIGINT, unless_repeated_n_times: int = False):
        """
        A context manager that intercepts the chosen signals after __enter__ and replays them when __exit__ is reached.

        Args:
            signals_to_delay (signal.Signals | list | tuple, optional): The signal or list/tuple of signals to delay. Defaults to signal.SIGINT.
            unless_repeated_n_times (int|bool, optional): If one signal is received this many times, let it through by exiting the
                context prematurely: handlers are restored and pending signals are dispatched right away. Defaults to False (never).

        Example usage:
            ```py
            with DelaySignals( signal.SIGINT ):
                time.sleep(10)  # If a SIGINT signal (KeyboardInterrupt) is received in this scope, it is delayed until the scope is exited
            ```

        Notes:
            Pending signals whose original handler is not callable (e.g. signal.SIG_DFL or signal.SIG_IGN) are dropped
            with a warning instead of being replayed.
            If a replayed handler raises (e.g. KeyboardInterrupt for SIGINT), every handler has already been restored,
            but signals queued after it are not replayed.

        Based on:
            http://stackoverflow.com/a/21919644/487556
            https://gist.github.com/tcwalther/ae058c64d5d9078a9f333913718bba95
        """
        # isinstance (not an exact type check) so list/tuple subclasses are accepted as collections too.
        self.signals_to_delay = signals_to_delay if isinstance(signals_to_delay, (list, tuple)) else [signals_to_delay]
        self.unless_repeated_n_times = unless_repeated_n_times

    def __enter__(self):
        # Remember each signal's current handler so __exit__ can restore it, then route the signal here.
        self.inboxes = {}
        for sig_type in self.signals_to_delay:
            self.inboxes[sig_type] = {'received': [], 'handler': signal.getsignal(sig_type)}
            signal.signal(sig_type, self.signal_handler)
        return self

    def signal_handler(self, sig, frame):
        """Record one receipt of `sig`; exit early once it has been received `unless_repeated_n_times` times."""
        self.inboxes[sig]['received'].append((sig, frame))

        if self.unless_repeated_n_times and len(self.inboxes[sig]['received']) >= self.unless_repeated_n_times:
            logging.warning(f"{type(self).__name__}: Signal {sig} repeated enough times ({self.unless_repeated_n_times}) to pass through, exiting early.")
            self.__exit__()
        else:
            logging.info(f"{type(self).__name__}: Signal {sig} handled.")

    def _restore_handlers(self):
        """Reinstall every original handler and return, then clear, the pending signals.

        Returns a list of (signal, original handler, received (sig, frame) pairs).
        All handlers are restored before anything is replayed, so a replayed handler that raises
        cannot leave other signals intercepted. Clearing the inboxes makes a second __exit__ call
        (an early exit followed by the normal end of the with-block) replay nothing twice.
        """
        pending = []
        for sig, inbox in self.inboxes.items():
            signal.signal(sig, inbox['handler'])
            pending.append((sig, inbox['handler'], inbox['received']))
            inbox['received'] = []
        return pending

    def __exit__(self, *_):
        for sig, handler, received in self._restore_handlers():
            if not received:
                continue
            if callable(handler):
                for msg in received:
                    handler(*msg)
            else:
                logging.warning(f"{type(self).__name__}: Signal {sig} had no prior handler, skipping.")


class IgnoreSignals(DelaySignals):
    """Like DelaySignals, but pending signals are discarded instead of replayed.

    A signal is let through only once it has been received `unless_repeated_n_times` times.
    Its original handler is then called once, with the first recorded (sig, frame) pair.
    With the default threshold (False), every intercepted signal is ignored.
    """

    def __exit__(self, *_):
        threshold = self.unless_repeated_n_times
        for sig, handler, received in self._restore_handlers():
            # Guard on `received` too, so a non-positive threshold cannot index an empty list.
            if not (threshold and received and len(received) >= threshold):
                continue
            if callable(handler):
                handler(*received[0])
            else:
                logging.warning(f"{type(self).__name__}: Signal {sig} had no prior handler, skipping.")
        

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment