Skip to content

Instantly share code, notes, and snippets.

@tonyseek
Created March 6, 2023 14:15
Show Gist options
  • Save tonyseek/963896eb5aac6198f86c29cb746f1459 to your computer and use it in GitHub Desktop.
Save tonyseek/963896eb5aac6198f86c29cb746f1459 to your computer and use it in GitHub Desktop.
Elect a leader in a multiprocess program using a Linux abstract-namespace Unix domain socket
import os
import signal
import collections
import logging
import resource
# Configure the root logger at DEBUG level so both the manager's INFO
# messages and the workers' DEBUG messages are emitted.
logging.basicConfig(level=logging.DEBUG)
class WorkerManager:
    """Spawn ``worker_num`` worker processes and keep that many alive.

    The controller is the process that constructs the manager.  It forks the
    workers (each runs ``worker_func(ctrl_pid)``) and then sleeps in
    ``signal.pause()``; everything else is driven by signal handlers:

    * SIGCHLD         -- reap exited children; ``run()`` replaces them.
    * SIGUSR1         -- print the controller/worker process tree.
    * SIGINT, SIGTERM -- send SIGTERM to every worker and stop supervising.
    """

    logger = logging.getLogger(__name__)

    def __init__(self, worker_num, worker_func):
        """Record the configuration and install the controller's handlers.

        :param worker_num: number of workers to keep running.
        :param worker_func: callable run in each forked child as
            ``worker_func(ctrl_pid)``.
        """
        self.ctrl_pid = os.getpid()
        self.ctrl_stopped = False
        self.worker_num = worker_num
        self.worker_pids = set()
        self.worker_func = worker_func
        # (pid, raw wait status) pairs reaped by handle_sigchild, drained by run().
        self.worker_exit_queue = collections.deque()
        # Set by handle_sigchild and cleared by run(); not read by this class.
        self.worker_exit_updated = False
        self.setup_signals()

    def setup_signals(self):
        """Install the controller's signal handlers (see the class docstring)."""
        signal.signal(signal.SIGCHLD, self.handle_sigchild)
        signal.signal(signal.SIGUSR1, self.handle_sigusr1)
        signal.signal(signal.SIGINT, self.handle_cleanup)
        signal.signal(signal.SIGTERM, self.handle_cleanup)

    def reset_signals(self):
        """Restore default dispositions; called in each freshly forked worker."""
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)
        signal.signal(signal.SIGUSR1, signal.SIG_DFL)
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        signal.signal(signal.SIGTERM, signal.SIG_DFL)

    def run(self):
        """Spawn the workers, supervise them until stopped, then reap them all.

        In the controller this returns once every child has been reaped.  In a
        forked worker it returns as soon as ``worker_func`` returns, so the
        child never runs the controller's loop.
        """
        for idx in range(self.worker_num):
            if not self.spawn_worker(f'NEW-{idx}'):
                return  # we are a worker whose worker_func returned
        self.print_tree()
        while not self.ctrl_stopped:
            try:
                # Only sleep when there is nothing to handle yet: a child death
                # or stop request delivered after the previous check has already
                # been recorded by its handler, and pause() would then wait for
                # yet another signal.  (A signal landing between this test and
                # the pause() call itself can still delay the loop until the
                # next signal.)
                if not self.worker_exit_queue and not self.ctrl_stopped:
                    signal.pause()
                while self.worker_exit_queue:
                    pid, status = self.worker_exit_queue.popleft()
                    # discard: an untracked pid is not worth crashing over.
                    self.worker_pids.discard(pid)
                    self.logger.info(
                        'Found worker died (pid=%d, status=%d)', pid, status)
            finally:
                self.worker_exit_updated = False
            if not self._supply_workers():
                return  # we are a replacement worker whose worker_func returned
        self._reap_all()

    def _supply_workers(self):
        """Fork replacements until ``worker_num`` workers are tracked again.

        Nothing is spawned once a stop has been requested: handle_cleanup()
        has already sent SIGTERM to the workers it knew about, so a late
        replacement would never be told to exit and ``_reap_all()`` would
        block on it forever.

        :returns: False in a forked replacement whose ``worker_func``
            returned (the caller must return), True in the controller.
        """
        if self.ctrl_stopped or len(self.worker_pids) >= self.worker_num:
            return True
        for idx in range(len(self.worker_pids), self.worker_num):
            if self.ctrl_stopped:  # stop requested while we were spawning
                break
            if not self.spawn_worker(f'SUPPLY-{idx}'):
                return False
        self.print_tree()
        return True

    def _reap_all(self):
        """Block until every remaining child has exited, logging each one."""
        while True:
            try:
                pid, status = os.waitpid(-1, 0)
            except ChildProcessError:
                break  # no children left
            if pid <= 0:
                break
            self.logger.info(
                'Cleanup worker (pid=%d, status=%d)', pid, status)

    def print_tree(self):
        """Print the controller pid, then the worker pids in ascending order."""
        worker_pids = sorted(self.worker_pids)
        print(f'CTRL: {self.ctrl_pid}')
        for idx, pid in enumerate(worker_pids[:-1]):
            print(f'├─ {idx:2}: {pid}')
        for pid in worker_pids[-1:]:
            print(f'└─ {len(worker_pids) - 1:2}: {pid}')

    def spawn_worker(self, annotation):
        """Fork one worker process.

        In the child: restore default signal dispositions, run
        ``worker_func(ctrl_pid)`` and return False once it returns.
        In the controller: track the child's pid, log it and return True.

        :param annotation: free-form label included in the log line.
        :raises RuntimeError: if called from any process but the controller.
        """
        if os.getpid() != self.ctrl_pid:
            raise RuntimeError('Please report a bug')
        worker_pid = os.fork()
        if worker_pid == 0:
            self.reset_signals()
            self.worker_func(self.ctrl_pid)
            return False
        # Track the pid before logging so a concurrent handle_cleanup() is as
        # likely as possible to see (and terminate) it.
        self.worker_pids.add(worker_pid)
        self.logger.info(
            'Spawned worker (pid=%d) # %s', worker_pid, annotation)
        return True

    def handle_sigchild(self, signum, frame):
        """SIGCHLD handler: reap every exited child without blocking.

        Reaped (pid, status) pairs are only queued here; run() processes them.
        """
        del frame
        if signum != signal.SIGCHLD:
            return
        while True:
            try:
                pid, status = os.waitpid(-1, os.WNOHANG)
            except ChildProcessError:
                break  # no children at all
            if pid <= 0:
                break  # children exist, but none has exited
            self.worker_exit_queue.append((pid, status))
            self.worker_exit_updated = True

    def handle_sigusr1(self, signum, frame):
        """SIGUSR1 handler: print the process tree."""
        del frame
        if signum != signal.SIGUSR1:
            return
        self.print_tree()

    def handle_cleanup(self, signum, frame):
        """SIGINT/SIGTERM handler: send SIGTERM to every live worker and stop.

        Workers whose exit is already queued are skipped: they have been
        reaped, so their pid may already belong to an unrelated process.
        """
        del frame
        if signum not in (signal.SIGTERM, signal.SIGINT):
            return
        # Snapshot the queue before iterating so a SIGCHLD handler appending
        # to it cannot disturb the iteration.
        reaped = {pid for pid, _ in list(self.worker_exit_queue)}
        for pid in self.worker_pids - reaped:
            try:
                os.kill(pid, signal.SIGTERM)
            except ProcessLookupError:
                continue
        self.ctrl_stopped = True
def worker(ctrl_pid):
    """Worker entry point: elect a leader among sibling workers, then run forever.

    The election relies on the Unix stream-socket connection state machine.
    Every worker tries to bind the same Linux abstract-namespace address (the
    leading NUL byte in ``addr``):

    * the worker whose ``bind()`` succeeds listens on it and becomes leader;
    * the others get EADDRINUSE, connect to the leader and become followers.
      A follower waits until its connection reports EOF or a reset, then
      votes again on a fresh socket.

    A follower whose ``connect()`` is refused also goes back to voting.

    :param ctrl_pid: pid of the controlling manager; only used to name this
        worker's logger.
    Never returns normally; unexpected socket errors propagate.
    """
    import socket
    import errno
    import logging
    import enum
    import time
    import select

    logger = logging.getLogger(f'worker-{ctrl_pid}:{os.getpid()}')

    class Role(enum.Enum):
        LEADER = 1
        FOLLOWER = 2
        VOTING = 3

    addr = "\0foo-bar-baz"  # abstract-namespace address (Linux-only)

    def renew(old):
        """Close *old* (if any) and return a fresh, unbound stream socket."""
        if old is not None:
            old.close()
        return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

    sock = None
    role = Role.VOTING
    while True:
        if role == Role.VOTING:
            # Vote on a brand-new socket: the previous one may already have
            # been used for connect(), and must not be reused for bind()/listen().
            sock = renew(sock)
            try:
                sock.bind(addr)
            except OSError as e:
                if e.errno == errno.EADDRINUSE:
                    role = Role.FOLLOWER
                else:
                    raise
            else:
                sock.listen(1)
                role = Role.LEADER
        if role == Role.LEADER:
            logger.info("I am the leader now")
            while True:
                logger.debug("I am doing leader's work")
                time.sleep(10)  # Finish leader's work
        else:
            sock = renew(sock)  # close the losing voter's socket instead of leaking it
            try:
                sock.connect(addr)
            except ConnectionRefusedError:
                role = Role.VOTING  # Try to be a new leader
                logger.debug("Try to be a new leader since we cannot connect")
                continue
            while True:
                readable, _, _ = select.select([sock], [], [], 30)
                if sock not in readable:
                    continue  # timed out; keep waiting on the leader
                try:
                    rx = sock.recv(1)
                except ConnectionResetError:
                    rx = b''
                # recv() returns bytes, so EOF is b''; the old comparison with
                # the str '' never matched and turned EOF into a busy loop.
                if not rx:
                    role = Role.VOTING  # Try to be a new leader
                    logger.debug("Try to be a new leader since EOF")
                    break
if __name__ == '__main__':
    # Guard against fork bombs while debugging: cap the process count (soft and
    # hard limit alike) before anything is forked.
    # NOTE(review): setrlimit(2) documents RLIMIT_NPROC as counting processes
    # of the whole real user ID, not only this program's -- confirm a cap of
    # 100 leaves enough headroom on the target machine.
    nproc_cap = 100
    resource.setrlimit(resource.RLIMIT_NPROC, (nproc_cap, nproc_cap))
    ctrl = WorkerManager(worker_num=20, worker_func=worker)
    ctrl.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment