Last active
April 23, 2019 00:26
-
-
Save reiver-dev/83834a391e69941830a4640408568e8b to your computer and use it in GitHub Desktop.
Running processes and passing stdio fds over unix sockets; might be useful for sidecar containers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
from array import array | |
from socket import (socket, AF_UNIX, SOCK_SEQPACKET, | |
CMSG_LEN, SOL_SOCKET, SCM_RIGHTS) | |
from errno import EADDRINUSE | |
from asyncio import ( | |
Future, Task, AbstractEventLoop, | |
get_running_loop as _get_running_loop, set_event_loop, | |
create_subprocess_exec, | |
wait as async_wait, FIRST_COMPLETED | |
) | |
from asyncio.unix_events import ( # type: ignore | |
_UnixSelectorEventLoop as UnixLoop | |
) | |
from asyncio.runners import ( # type: ignore | |
_cancel_all_tasks as cancel_all_tasks | |
) | |
from asyncio.subprocess import Process | |
import logging | |
from subprocess import DEVNULL | |
from pathlib import Path | |
from argparse import ArgumentParser, ArgumentTypeError | |
import json | |
import struct | |
import signal | |
from contextlib import contextmanager | |
from typing import (Sequence, Set, Mapping, | |
Optional, Tuple, Any, Iterable) | |
# 64 KiB; not referenced anywhere else in the visible code.
DEFAULT_LIMIT = 1 << 16

# One ancillary-data item: (level, type, payload).
AncData = Tuple[int, int, bytes]
# Shape of a received message: (data, ancillary items, flags, address).
Msg = Tuple[bytes, Tuple[AncData, ...], int, Any]

# Logger name: the module name when imported, the file stem when run
# directly as a script.
NAME = Path(__file__).stem if __name__ == '__main__' else __name__

_log = logging.getLogger(NAME)
debug = _log.debug
def itob(value: int) -> bytes:
    """Encode ``value`` as a 4-byte little-endian signed integer."""
    codec = struct.Struct('<i')
    return codec.pack(value)
def btoi(value: bytes) -> int:
    """Decode a 4-byte little-endian signed integer from ``value``."""
    (number,) = struct.unpack('<i', value)
    return number
class Loop(UnixLoop, AbstractEventLoop):
    """Unix selector event loop with ``recvmsg``/``sendmsg`` coroutines.

    Both operations first try the system call directly. If it would block,
    a reader/writer callback is registered on the socket's descriptor and
    the call is retried once the descriptor becomes ready. A watcher left
    behind by a cancelled future is only removed the next time that
    callback fires.
    """

    async def sock_recvmsg(self, sock: socket, bufsize: int,
                           ancbufsize: int = 0, flags: int = 0) -> Msg:
        """Receive one message with ancillary data from ``sock``.

        :param sock: non-blocking socket to read from.
        :param bufsize: maximum number of payload bytes to receive.
        :param ancbufsize: size of the ancillary-data buffer.
        :param flags: flags forwarded to ``socket.recvmsg``.
        :returns: whatever ``socket.recvmsg`` returned.
        :raises ValueError: in loop debug mode, if ``sock`` is blocking.
        """
        if self._debug and sock.gettimeout() != 0:
            raise ValueError("the socket must be non-blocking")
        future = self.create_future()
        self._sock_recvmsg(future, None, sock, bufsize, ancbufsize, flags)
        return await future

    async def sock_sendmsg(self, sock: socket,
                           data: Iterable[bytes],
                           ancdata: Iterable[AncData] = ()) -> int:
        """Send ``data`` buffers plus ``ancdata`` items over ``sock``.

        :returns: the value returned by ``socket.sendmsg``.
        :raises ValueError: in loop debug mode, if ``sock`` is blocking.
        """
        if self._debug and sock.gettimeout() != 0:
            raise ValueError("the socket must be non-blocking")
        future = self.create_future()
        self._sock_sendmsg(future, None, sock, data, ancdata)
        return await future

    def _sock_recvmsg(self, future: Future, registered_fd: Optional[int],
                      sock: socket, bufsize: int, ancbufsize: int = 0,
                      flags: int = 0):
        """Try ``recvmsg`` once; re-arm a reader if it would block.

        ``registered_fd`` is the descriptor this callback was registered
        on, or ``None`` for the initial direct call.
        """
        if registered_fd is not None:
            # The reader is one-shot: always unregister before retrying.
            self.remove_reader(registered_fd)
        if future.cancelled():
            return
        try:
            debug('recvmsg data=%r ancdata=%r', bufsize, ancbufsize)
            # Forward ``flags``: previously it was accepted but dropped.
            result = sock.recvmsg(bufsize, ancbufsize, flags)
        except (BlockingIOError, InterruptedError):
            debug('wouldblock')
            fd = sock.fileno()
            self.add_reader(fd, self._sock_recvmsg, future, fd,
                            sock, bufsize, ancbufsize, flags)
        except Exception as exc:
            future.set_exception(exc)
        else:
            future.set_result(result)

    def _sock_sendmsg(self, future: Future, registered_fd: Optional[int],
                      sock: socket, data: Iterable[bytes],
                      ancdata: Iterable[AncData]):
        """Try ``sendmsg`` once; re-arm a writer if it would block.

        ``registered_fd`` is the descriptor this callback was registered
        on, or ``None`` for the initial direct call.
        """
        if registered_fd is not None:
            # The writer is one-shot: always unregister before retrying.
            self.remove_writer(registered_fd)
        if future.cancelled():
            return
        try:
            debug('sendmsg data=%r ancdata=%r', data, ancdata)
            result = sock.sendmsg(data, ancdata)
        except (BlockingIOError, InterruptedError):
            debug('wouldblock')
            fd = sock.fileno()
            self.add_writer(fd, self._sock_sendmsg, future,
                            fd, sock, data, ancdata)
        except Exception as exc:
            future.set_exception(exc)
        else:
            future.set_result(result)
def get_running_loop() -> Loop:
    """Return the running event loop, annotated as :class:`Loop`.

    No runtime check is made that the running loop really is a ``Loop``.
    """
    running = _get_running_loop()
    return running  # type: ignore
async def sendfd(loop: Loop, sock: socket,
                 msg: Iterable[bytes],
                 fds: Iterable[int]) -> int:
    """Send ``msg`` over ``sock``, attaching ``fds`` as SCM_RIGHTS data.

    :param loop: event loop providing ``sock_sendmsg``.
    :param sock: non-blocking unix socket.
    :param msg: payload buffers.
    :param fds: file descriptors to pass; may be empty or ``None``.
    :returns: the value returned by ``loop.sock_sendmsg``.
    """
    # Materialize first: ``fds`` may be a one-shot iterator, whose
    # truthiness says nothing about whether it yields any descriptors.
    packed = array('i', () if fds is None else fds)
    ancdata = []
    if packed:
        ancdata.append((SOL_SOCKET, SCM_RIGHTS, bytes(packed)))
    return await loop.sock_sendmsg(sock, msg, ancdata)
async def recvfd(loop: Loop, sock: socket, msglen: int,
                 maxfds: int = 3) -> Tuple[bytes, Sequence[int]]:
    """Receive one message and any file descriptors sent with it.

    :param loop: event loop providing ``sock_recvmsg``.
    :param sock: non-blocking unix socket.
    :param msglen: maximum payload size in bytes.
    :param maxfds: number of descriptors to reserve ancillary space for;
        a falsy value requests no ancillary buffer.
    :returns: ``(payload, descriptors)``.
    """
    fds = array('i')
    if maxfds:
        ancsize = CMSG_LEN(maxfds * fds.itemsize)
    else:
        ancsize = 0
    received = await loop.sock_recvmsg(sock, msglen, ancsize)
    msg, ancdata, _flags, _addr = received
    debug('recvfd msg=%r acndata=%r flags=%d addr=%s',
          msg, ancdata, _flags, _addr)
    for level, msgtype, data in ancdata:
        if level != SOL_SOCKET or msgtype != SCM_RIGHTS:
            continue
        # Drop trailing bytes that do not form a whole integer.
        usable = len(data) - len(data) % fds.itemsize
        fds.frombytes(data[:usable])
    return msg, list(fds)
def env_var_pair(value):
    """Parse a ``KEY=VAL`` command-line argument into ``(key, val)``.

    Only the first ``=`` separates name and value, so the value may itself
    contain ``=`` and may be empty.

    :raises ArgumentTypeError: if ``value`` is not a string, has no ``=``
        separator, or has an empty name.
    """
    try:
        name, separator, content = value.partition('=')
    except (AttributeError, TypeError):
        separator = name = content = None
    # ``partition`` never raises for strings, so validate explicitly.
    if not separator or not name:
        raise ArgumentTypeError(
            '{!r} is not a valid env var pair, expected KEY=VAL'.format(value)
        )
    return (name, content)
def setup_exec_arguments(parser: ArgumentParser):
    """Register the options of the ``exec`` sub-command on ``parser``."""
    # The socket path is mandatory: without it there is nothing to connect
    # to, and the client would fail later with an obscure TypeError.
    parser.add_argument('-c', '--connect', type=Path,
                        metavar='PATH',
                        required=True,
                        help='path to server socket')
    parser.add_argument('-e', '--env', type=env_var_pair,
                        metavar='KEY=VAL',
                        help='environment variable pairs',
                        action='append')
    parser.add_argument('-w', '--workdir', metavar='PATH',
                        help='working directory')
    # '...' (argparse.REMAINDER) collects everything after the options.
    parser.add_argument('program', nargs='...',
                        help='program to execute')
def setup_serve_arguments(parser: ArgumentParser):
    """Register the positional socket path of the ``serve`` sub-command."""
    options = {'type': Path, 'help': 'path to unix socket'}
    parser.add_argument('path', **options)
def server_socket(path: Path) -> socket:
    """Create a non-blocking SOCK_SEQPACKET unix socket bound to ``path``.

    The socket is closed again if binding fails.

    :raises OSError: if binding fails; an address-in-use failure is
        re-raised with a message naming ``path``.
    """
    sock = socket(family=AF_UNIX, type=SOCK_SEQPACKET)
    try:
        sock.bind(os.fspath(path))
    except Exception as exc:
        sock.close()
        in_use = isinstance(exc, OSError) and exc.errno == EADDRINUSE
        if not in_use:
            raise
        msg = 'Address `{}` is already in use'.format(path)
        raise OSError(EADDRINUSE, msg) from None
    sock.setblocking(False)
    return sock
def client_socket(path: Path) -> socket:
    """Connect a SOCK_SEQPACKET unix socket to ``path``.

    The socket is switched to non-blocking mode after connecting and is
    closed again if either step fails.
    """
    sock = socket(family=AF_UNIX, type=SOCK_SEQPACKET)
    try:
        sock.connect(os.fspath(path))
        sock.setblocking(False)
    except Exception:
        sock.close()
        raise
    return sock
def gather_fds(fds: Sequence[int], items: Sequence[int]) -> Sequence[int]:
    """Map received descriptors onto the requested stream slots.

    ``items`` holds one value per stream. A non-negative value means
    "a descriptor was sent for this stream" and is replaced by the next
    entry of ``fds``; a negative value (such as ``subprocess.DEVNULL``) is
    kept as is. The result always has the same length as ``items``, so a
    caller can unpack it positionally even when some streams are absent.

    :raises ValueError: if ``fds`` holds fewer descriptors than ``items``
        asks for.
    """
    available = iter(fds)
    gathered = []
    for item in items:
        if item < 0:
            gathered.append(item)
            continue
        try:
            gathered.append(next(available))
        except StopIteration:
            wanted = sum(1 for entry in items if entry >= 0)
            raise ValueError(
                'not enough file descriptors: need {}, got {}'.format(
                    wanted, len(fds))
            ) from None
    return gathered
async def handle_process(loop: Loop, sock: socket, process: Process) -> int:
    """Relay signals from the client to ``process`` until it exits.

    Each non-empty client message must be exactly 4 bytes and is decoded
    with ``btoi`` as a signal number. When the process exits, its return
    code is sent to the client as 4 bytes. An empty message means the
    client disconnected, in which case the process is killed.

    :returns: the return code of the process.
    :raises ValueError: if a client message is not 4 bytes long.
    """
    pid = process.pid
    inp = loop.create_task(loop.sock_recvmsg(sock, 4096))
    finish = loop.create_task(process.wait())
    done: Set[Task]
    pending: Set[Task]
    pending = {inp, finish}
    while True:
        debug('waiting for process=%d to finish', pid)
        (done, pending) = await async_wait(pending,  # type: ignore
                                           return_when=FIRST_COMPLETED)
        debug('event occured')
        if finish in done:
            inp.cancel()
            retcode = finish.result()
            debug('process finished pid=%d ret=%d', pid, retcode)
            await loop.sock_sendmsg(sock, [itob(retcode)])
            return retcode
        if inp in done:
            debug('signal received')
            msg = inp.result()[0]
            if not msg:
                debug('client disconnected')
                try:
                    process.kill()
                except ProcessLookupError:
                    # The process exited on its own in the meantime.
                    debug('process already gone pid=%d', pid)
                return await finish
            if len(msg) != 4:
                # Report the offending payload, not the ``signal`` module.
                raise ValueError('wrong signal data: {!r}'.format(msg))
            sigval = btoi(msg)
            debug('process signal request pid=%d sig=%d', pid, sigval)
            try:
                process.send_signal(sigval)
            except ProcessLookupError:
                # The process may have exited before the signal arrived;
                # its exit is reported on the next loop iteration.
                debug('process already gone pid=%d sig=%d', pid, sigval)
            inp = loop.create_task(loop.sock_recvmsg(sock, 4096))
            pending.add(inp)
async def handle_client(loop: Loop, sock: socket):
    """Serve one client: spawn the requested program and supervise it.

    The first message is a UTF-8 JSON object with the key ``argv`` and the
    optional keys ``env``, ``cwd`` and ``io``, accompanied by up to three
    file descriptors used as the child's stdin/stdout/stderr. The reply is
    a JSON object with ``success`` and ``pid`` (plus ``message`` and
    ``errno`` on failure). After a successful start, control passes to
    ``handle_process``.

    :returns: the process return code, or ``None`` if it failed to start.
    """
    msg, fds = await recvfd(loop, sock, 4096, 3)
    try:
        request = json.loads(msg.decode('utf-8'))
        debug('request=%s fds=%s', request, fds)
        argv = request['argv']
        env = request.get('env', None)
        cwd = request.get('cwd', None)
        streams = request.get('io') or {}
        debug('process argv=%s env=%s cwd=%s', argv, env, cwd)

        environ = None
        if env:
            # Requested variables are layered over the server environment;
            # values may reference existing variables ($NAME).
            environ = os.environ.copy()
            for name, value in env.items():
                environ[name] = os.path.expandvars(value)
        workdir = str(cwd) if cwd else None

        sin, sout, serr = gather_fds(fds, [streams.get(n, DEVNULL)
                                           for n in ('in', 'out', 'err')])
        try:
            process = await create_subprocess_exec(*argv,
                                                   stdin=sin,
                                                   stdout=sout,
                                                   stderr=serr,
                                                   env=environ,
                                                   cwd=workdir)
        except Exception as err:
            result = {'success': False,
                      'message': str(err),
                      'errno': getattr(err, 'errno', None),
                      'pid': None}
            await loop.sock_sendmsg(
                sock, [json.dumps(result).encode('utf-8')]
            )
            return
    finally:
        # Release every descriptor received from the client, whether the
        # spawn succeeded or the request turned out to be malformed.
        for fd in fds:
            os.close(fd)

    pid = process.pid
    debug('process pid=%d', pid)
    await loop.sock_sendmsg(
        sock, [json.dumps({'success': True, 'pid': pid}).encode('utf-8')]
    )
    try:
        return await handle_process(loop, sock, process)
    except Exception:
        # Do not leave the child running if supervising it failed.
        process.kill()
        await process.wait()
        raise
def sock_close_cb(sock: socket):
    """Build a task done-callback that closes ``sock``.

    The returned callable takes the finished task (used only for logging).
    """
    def _close_after(task: Task):
        debug('closing socket sock=%r after task=%r', sock, task)
        sock.close()

    return _close_after
def _release_client(loop: Loop, sock: socket):
    """Build the done-callback that tears down one client connection."""
    close = sock_close_cb(sock)

    def release(task: Task):
        fd = sock.fileno()
        if fd >= 0:
            # A cancelled recvmsg/sendmsg may still have a watcher
            # registered; drop it before the descriptor number is reused.
            loop.remove_reader(fd)
            loop.remove_writer(fd)
        close(task)
        if not task.cancelled() and task.exception() is not None:
            _log.error('client handler failed', exc_info=task.exception())

    return release


async def accept(loop: Loop, server_sock: socket):
    """Accept clients on ``server_sock`` forever, one task per client.

    Each accepted socket is closed when its handler task finishes, and a
    handler failure is logged rather than left unretrieved.
    """
    server_sock.listen()
    while True:
        client, addr = await loop.sock_accept(server_sock)
        debug('connected sock=%r addr=%r', client, addr)
        task = loop.create_task(handle_client(loop, client))
        task.add_done_callback(_release_client(loop, client))
def forwarded_signals() -> Iterable[int]:
    """Return the set of signals the client forwards to the server.

    Names not defined on the current platform are skipped. SIGKILL and
    SIGSTOP are left out because no handler can be installed for them;
    they are identified by name, since their numbers differ by platform.
    """
    names = (
        'SIGABRT',
        'SIGALRM',
        'SIGBUS',
        'SIGCHLD',
        'SIGCLD',
        'SIGCONT',
        'SIGEMT',
        'SIGFPE',
        'SIGHUP',
        'SIGILL',
        'SIGINFO',
        'SIGINT',
        'SIGIO',
        'SIGIOT',
        'SIGLOST',
        'SIGPIPE',
        'SIGPOLL',
        'SIGPROF',
        'SIGPWR',
        'SIGQUIT',
        'SIGSEGV',
        'SIGSTKFLT',
        'SIGTSTP',
        'SIGSYS',
        'SIGTERM',
        'SIGTRAP',
        'SIGTTIN',
        'SIGTTOU',
        'SIGUNUSED',
        'SIGURG',
        'SIGUSR1',
        'SIGUSR2',
        'SIGVTALRM',
        'SIGXCPU',
        'SIGXFSZ',
        'SIGWINCH',
    )
    uncatchable = {
        sig
        for sig in (getattr(signal, name, None)
                    for name in ('SIGKILL', 'SIGSTOP'))
        if sig is not None
    }
    candidates = (getattr(signal, name, None) for name in names)
    return frozenset(
        sig for sig in candidates
        if sig is not None and sig not in uncatchable
    )
async def send_signal(loop, sock, sig):
    """Send signal number ``sig`` over ``sock`` as one ``itob`` payload."""
    payload = [itob(sig)]
    await loop.sock_sendmsg(sock, payload)
@contextmanager
def setup_signal_forwarding(loop: Loop, sock: socket):
    """Forward signals to the peer of ``sock`` while the context is active.

    On entry a handler is installed on ``loop`` for every signal returned
    by ``forwarded_signals``; each handler schedules a ``send_signal``
    task. The handlers are removed again on exit.
    """
    def relay(signum):
        loop.create_task(send_signal(loop, sock, signum))

    watched = forwarded_signals()
    for signum in watched:
        loop.add_signal_handler(signum, relay, signum)
    try:
        yield
    finally:
        for signum in watched:
            loop.remove_signal_handler(signum)
async def connect(loop: Loop, client_sock: socket, argv: Sequence[str],
                  env: Mapping[str, str] = None, cwd: str = None) -> int:
    """Ask the server to run ``argv`` on this process's stdio; wait for it.

    The request is sent as JSON together with the stdin/stdout/stderr
    descriptors of this process. While waiting, received signals are
    forwarded to the server.

    :returns: the remote process's return code, or 127 if the process
        could not be started or the server's answer was invalid.
    """
    streams = {
        'in': sys.stdin.fileno(),
        'out': sys.stdout.fileno(),
        'err': sys.stderr.fileno()
    }
    request = {
        'argv': argv,
        'env': env,
        'cwd': cwd,
        'io': streams,
    }
    debug('requesting process %s', request)
    await sendfd(loop, client_sock,
                 [json.dumps(request).encode('utf-8')],
                 list(streams.values()))
    debug('waiting')
    with setup_signal_forwarding(loop, client_sock):
        msg = (await loop.sock_recvmsg(client_sock, 4096))[0]
        if not msg:
            # An empty message means the server hung up without answering.
            _log.error('server closed the connection before responding')
            return 127
        procstart = json.loads(msg.decode('utf-8'))
        debug('process response=%s', procstart)
        if not procstart['success']:
            # Surface the server's reason instead of failing silently.
            _log.error('failed to start process: %s',
                       procstart.get('message'))
            return 127
        debug('waiting process to finish')
        msg = (await loop.sock_recvmsg(client_sock, 4096))[0]
        if not msg or len(msg) != 4:
            _log.error('invalid response=%r', msg)
            return 127
        retcode = btoi(msg)
        debug('return code %d', retcode)
        return retcode
async def create_server(path: Path, loop: Loop = None):
    """Bind a server socket at ``path`` and accept clients until stopped.

    Uses the running loop when ``loop`` is not given. On the way out the
    socket is closed and the socket file is removed.
    """
    sock = server_socket(path)
    try:
        active = get_running_loop() if loop is None else loop
        await accept(active, sock)
    finally:
        if sock:
            sock.close()
        os.unlink(path)
async def create_client(path: Path, argv: Sequence[str],
                        env: Mapping[str, str] = None,
                        cwd: str = None,
                        loop: Loop = None) -> int:
    """Connect to the server at ``path`` and run ``argv`` through it.

    Uses the running loop when ``loop`` is not given. The client socket
    is closed before returning.

    :returns: the value returned by ``connect``.
    """
    sock = client_socket(path)
    try:
        active = get_running_loop() if loop is None else loop
        retcode = await connect(active, sock, argv, env, cwd)
        return retcode
    finally:
        if sock:
            sock.close()
def serve(argv) -> int:
    """Entry point of the ``serve`` command; returns exit status 0."""
    server = create_server(argv.path)
    run_forever(server)
    return 0
def execute(argv) -> int:
    """Entry point of the ``exec`` command.

    Returns 0 right away when no program was given; otherwise returns
    whatever the client coroutine produced.
    """
    if not argv.program:
        return 0
    env = dict(argv.env) if argv.env is not None else None
    client = create_client(argv.connect, argv.program, env, argv.workdir)
    return run_forever(client)
def run_forever(main):
    """Run the awaitable ``main`` to completion on a fresh :class:`Loop`.

    The loop is installed as the current event loop with debug mode on.
    Afterwards the remaining tasks are cancelled, async generators are
    shut down, the current loop is reset and the loop is closed.

    :returns: the result of ``main``.
    """
    loop = Loop()

    def teardown():
        try:
            cancel_all_tasks(loop)
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            set_event_loop(None)
            loop.close()

    try:
        set_event_loop(loop)
        loop.set_debug(True)
        result = loop.run_until_complete(main)
    finally:
        teardown()
    return result
def main(argv: Sequence[str] = None):
    """Parse the command line and run the selected sub-command.

    :param argv: full argument vector including the program name at
        index 0 (as in ``sys.argv``); defaults to ``sys.argv``.
    :returns: the exit status produced by the sub-command.
    """
    parser = ArgumentParser(NAME)
    parser.add_argument('-v', '--verbose', action='count',
                        default=0,
                        help='enable debug messages')
    commands = parser.add_subparsers(help='commands')
    cmd_server = commands.add_parser('serve', help='launch server')
    cmd_server.set_defaults(command=serve)
    cmd_executor = commands.add_parser('exec', help='execute command')
    cmd_executor.set_defaults(command=execute)
    setup_serve_arguments(cmd_server)
    setup_exec_arguments(cmd_executor)

    # Honour the caller-supplied vector; element 0 is the program name.
    cli = list(sys.argv if argv is None else argv)
    params = parser.parse_args(cli[1:])

    level = logging.DEBUG if params.verbose > 0 else logging.WARN
    logging.basicConfig(level=level, stream=sys.stderr)
    debug('args=%s', params)

    handler = getattr(params, 'command', None)
    if handler is None:
        # No sub-command given: print usage and exit instead of failing
        # with AttributeError.
        parser.error('a command is required: serve or exec')
    return handler(params)
# Script entry point: the process exit status is the value main() returns.
if __name__ == '__main__':
    exit_status = main(sys.argv)
    sys.exit(exit_status)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment