Created
March 6, 2023 14:15
-
-
Save tonyseek/963896eb5aac6198f86c29cb746f1459 to your computer and use it in GitHub Desktop.
Elect a leader in a multiprocess program using an abstract Unix domain socket
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import signal | |
import collections | |
import logging | |
import resource | |
# Emit DEBUG-level records so the debug messages logged by the worker
# processes are shown, not just the INFO-level ones.
logging.basicConfig(level=logging.DEBUG)
class WorkerManager:
    """Spawn a fixed number of forked worker processes and keep the pool full.

    The controlling process forks ``worker_num`` children, each of which
    calls ``worker_func(ctrl_pid)``.  Exited children are reaped in the
    SIGCHLD handler and replaced from the main loop.  SIGUSR1 prints the
    process tree; SIGINT/SIGTERM send SIGTERM to every worker and stop the
    main loop, after which the remaining children are waited for.

    Constructing an instance installs the signal handlers for the current
    process (see ``setup_signals``).
    """

    logger = logging.getLogger(__name__)

    def __init__(self, worker_num, worker_func):
        """Record the pool configuration and install the signal handlers.

        Args:
            worker_num: Number of worker processes to keep alive.
            worker_func: Callable run in each child; it is called with the
                controlling process' pid as its only argument.
        """
        self.ctrl_pid = os.getpid()
        self.ctrl_stopped = False
        self.worker_num = worker_num
        self.worker_pids = set()
        self.worker_func = worker_func
        # (pid, status) pairs appended by the SIGCHLD handler and consumed
        # by the main loop in run().
        self.worker_exit_queue = collections.deque()
        self.worker_exit_updated = False
        self.setup_signals()

    def setup_signals(self):
        """Install this manager's handlers for SIGCHLD/SIGUSR1/SIGINT/SIGTERM."""
        signal.signal(signal.SIGCHLD, self.handle_sigchild)
        signal.signal(signal.SIGUSR1, self.handle_sigusr1)
        signal.signal(signal.SIGINT, self.handle_cleanup)
        signal.signal(signal.SIGTERM, self.handle_cleanup)

    def reset_signals(self):
        """Restore the default disposition of the signals handled here.

        Called in a freshly forked child so that it does not run the
        controller's handlers.
        """
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)
        signal.signal(signal.SIGUSR1, signal.SIG_DFL)
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        signal.signal(signal.SIGTERM, signal.SIG_DFL)

    def run(self):
        """Spawn the initial workers and supervise them until asked to stop.

        In the controlling process this returns after a SIGINT/SIGTERM has
        been handled and every remaining child has been waited for.  In a
        forked child it returns as soon as ``worker_func`` returns (that is
        what a ``False`` result from ``spawn_worker`` means).
        """
        for idx in range(self.worker_num):
            if not self.spawn_worker(f'NEW-{idx}'):
                return
        self.print_tree()

        while not self.ctrl_stopped:
            try:
                # Only sleep when nothing is queued: a SIGCHLD delivered
                # while we were spawning or printing must be handled now,
                # not after some later, unrelated signal wakes us up.
                # NOTE: a signal arriving between this check and pause()
                # is still only noticed on the next signal.
                if not self.worker_exit_queue:
                    signal.pause()
                self._drain_exit_queue()
            finally:
                self.worker_exit_updated = False
            if not self._refill_workers():
                return

        self._wait_for_children()

    def _drain_exit_queue(self):
        """Forget every worker whose exit the SIGCHLD handler has queued."""
        while self.worker_exit_queue:
            (pid, status) = self.worker_exit_queue.popleft()
            # discard(): a pid we are not tracking must not crash the loop.
            self.worker_pids.discard(pid)
            self.logger.info(
                'Found worker died (pid=%d, status=%d)', pid, status)

    def _refill_workers(self):
        """Spawn replacements until the pool is back to ``worker_num``.

        Nothing is spawned once a stop has been requested; otherwise the
        workers that were just sent SIGTERM would be replaced by new ones
        that nobody terminates.

        Returns:
            False when running in a forked child whose ``worker_func`` has
            returned (the caller must unwind), True otherwise.
        """
        if self.ctrl_stopped:
            return True
        alive = len(self.worker_pids)
        if alive >= self.worker_num:
            return True
        for idx in range(alive, self.worker_num):
            if not self.spawn_worker(f'SUPPLY-{idx}'):
                return False
        self.print_tree()
        return True

    def _wait_for_children(self):
        """Block until no child process is left to wait for."""
        while True:
            try:
                (pid, status) = os.waitpid(-1, 0)
            except ChildProcessError:
                break
            if pid <= 0:
                break
            self.logger.info(
                'Cleanup worker (pid=%d, status=%d)', pid, status)

    def print_tree(self):
        """Print the controller pid followed by the sorted worker pids."""
        worker_pids = sorted(self.worker_pids)
        print(f'CTRL: {self.ctrl_pid}')
        last = len(worker_pids) - 1
        for idx, pid in enumerate(worker_pids):
            branch = '└─' if idx == last else '├─'
            print(f'{branch} {idx:2}: {pid}')

    def spawn_worker(self, annotation):
        """Fork one worker process.

        Args:
            annotation: Free-form tag included in the log line.

        Returns:
            True in the controlling process after the child pid has been
            recorded; False in the child once ``worker_func`` has returned.

        Raises:
            RuntimeError: If called from a process other than the one that
                created this manager.
        """
        if os.getpid() != self.ctrl_pid:
            raise RuntimeError('Please report a bug')
        worker_pid = os.fork()
        if worker_pid == 0:
            # Child: drop the controller's handlers before running the job.
            self.reset_signals()
            self.worker_func(self.ctrl_pid)
            return False
        self.logger.info(
            'Spawned worker (pid=%d) # %s', worker_pid, annotation)
        self.worker_pids.add(worker_pid)
        return True

    def handle_sigchild(self, signum, frame):
        """SIGCHLD handler: reap every exited child without blocking.

        Each reaped ``(pid, status)`` pair is queued for the main loop;
        ``status`` is the raw value returned by ``os.waitpid``.
        """
        del frame
        if signum != signal.SIGCHLD:
            return
        while True:
            try:
                (pid, status) = os.waitpid(-1, os.WNOHANG)
            except ChildProcessError:
                break
            if pid <= 0:
                break
            self.worker_exit_queue.append((pid, status))
            self.worker_exit_updated = True

    def handle_sigusr1(self, signum, frame):
        """SIGUSR1 handler: print the current process tree."""
        del frame
        if signum != signal.SIGUSR1:
            return
        self.print_tree()

    def handle_cleanup(self, signum, frame):
        """SIGINT/SIGTERM handler: terminate the workers and stop run()."""
        del frame
        if signum not in (signal.SIGTERM, signal.SIGINT):
            return
        for pid in self.worker_pids:
            try:
                os.kill(pid, signal.SIGTERM)
            except ProcessLookupError:
                # Already gone; nothing to terminate.
                continue
        self.ctrl_stopped = True
def worker(ctrl_pid, addr="\0foo-bar-baz"):
    """Run one worker: become the leader or follow the current leader.

    Election works through a stream socket in the abstract Unix-domain
    namespace (the address starts with a NUL byte; Linux-only).  Only one
    process can bind the address, and that process is the leader.  Every
    other worker connects to it and waits; when the connection is refused,
    reset or closed, the worker tries to bind the address itself.

    This function never returns normally: the leader loops forever doing
    its work and followers keep watching the leader.

    Args:
        ctrl_pid: Pid of the controlling process; only used in the logger
            name.
        addr: Socket address used for the election.  All workers taking
            part in the same election must use the same address.

    Raises:
        OSError: If binding fails for a reason other than the address
            already being in use.
    """
    import socket
    import errno
    import logging
    import enum
    import time
    import select

    logger = logging.getLogger(f'worker-{ctrl_pid}:{os.getpid()}')

    class Role(enum.Enum):
        LEADER = 1
        FOLLOWER = 2
        VOTING = 3

    role = Role.VOTING
    leader_sock = None  # kept referenced so the listening socket stays open

    while True:
        if role == Role.VOTING:
            # Always vote with a fresh socket: one that has already been
            # connected as a follower cannot be turned into a listener.
            candidate = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                candidate.bind(addr)  # abstract domain socket (linux-only)
            except OSError as e:
                candidate.close()
                if e.errno != errno.EADDRINUSE:
                    raise
                role = Role.FOLLOWER
            else:
                candidate.listen(1)
                leader_sock = candidate
                role = Role.LEADER

        if role == Role.LEADER:
            logger.info("I am the leader now")
            while True:
                logger.debug("I am doing leader's work")
                time.sleep(10)  # Finish leader's work
        else:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as link:
                try:
                    link.connect(addr)
                except ConnectionRefusedError:
                    role = Role.VOTING  # Try to be a new leader
                    logger.debug(
                        "Try to be a new leader since we cannot connect")
                    continue
                while True:
                    readable, _, _ = select.select([link], [], [], 30)
                    if not readable:
                        continue  # timeout: leader still there, keep waiting
                    try:
                        rx = link.recv(1)
                    except ConnectionResetError:
                        rx = b''
                    # recv() returns bytes, so EOF is b'' (never '').
                    if not rx:
                        role = Role.VOTING  # Try to be a new leader
                        logger.debug("Try to be a new leader since EOF")
                        break
if __name__ == '__main__':
    # Cap the number of processes (soft and hard limit alike) so that a bug
    # in the spawning logic cannot turn into a fork bomb while debugging.
    soft_limit = hard_limit = 100
    resource.setrlimit(resource.RLIMIT_NPROC, (soft_limit, hard_limit))

    # Supervise 20 workers that elect a leader among themselves.
    ctrl = WorkerManager(worker_num=20, worker_func=worker)
    ctrl.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment