Skip to content

Instantly share code, notes, and snippets.

@mchapman87501
Last active August 6, 2019 00:36
Show Gist options
  • Save mchapman87501/e1e369c628125e7b300923630cd600be to your computer and use it in GitHub Desktop.
Save mchapman87501/e1e369c628125e7b300923630cd600be to your computer and use it in GitHub Desktop.
Transfer file descriptors across a socket - python, unix
#!/usr/bin/env python3
# I finally read the documentation for socket.sendmsg, and there it
# was: How to send a list of file descriptors over an AF_UNIX socket:
# https://docs.python.org/3/library/socket.html?highlight=sendmsg#socket.socket.sendmsg
# https://docs.python.org/3/library/socket.html?highlight=sendmsg#socket.socket.recvmsg
import array
import socket
import typing as tp
def send_fds(sock: socket.socket, fds: tp.Iterable[int]) -> None:
    """Send the file descriptors *fds* over the AF_UNIX socket *sock*.

    The descriptors travel as SCM_RIGHTS ancillary data attached to a
    one-byte dummy payload; the payload byte itself carries no meaning
    and is discarded by ``recv_fds``.

    Raises:
        RuntimeError: if ``sendmsg`` reports sending a different number of
            payload bytes than expected.
    """
    # unix(7): on Linux at least one byte of ordinary data must be sent in
    # the same call for ancillary data to be delivered over a UNIX stream
    # socket.  With an empty payload the descriptors never reach the peer.
    msg = b"\x00"
    fd_array = array.array("i", fds).tobytes()
    ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fd_array)]
    msg_bytes_sent = sock.sendmsg([msg], ancillary)
    if msg_bytes_sent != len(msg):
        raise RuntimeError(f"Sent {msg_bytes_sent} of expected {len(msg)}")
def recv_fds(sock: socket.socket, maxfds: int) -> tp.List[int]:
    """Receive file descriptors sent over the AF_UNIX socket *sock*.

    A single ``recvmsg`` call is made with room for up to *maxfds*
    descriptors of SCM_RIGHTS ancillary data.  Up to 256 bytes of ordinary
    payload are read in the same call and discarded.  The returned message
    flags are not inspected, so truncated control data (MSG_CTRUNC) is not
    reported.

    Returns:
        The received descriptors in the order they were found; an empty
        list if no SCM_RIGHTS data arrived.
    """
    collected = array.array("i")
    width = collected.itemsize
    ancillary_room = socket.CMSG_LEN(maxfds * width)
    _payload, cmsgs, _flags, _addr = sock.recvmsg(256, ancillary_room)
    for level, kind, blob in cmsgs:
        if level != socket.SOL_SOCKET or kind != socket.SCM_RIGHTS:
            continue
        # Keep only whole integers; a trailing partial one is dropped.
        whole = (len(blob) // width) * width
        collected.frombytes(blob[:whole])
    return list(collected)
#!/usr/bin/env python3
import concurrent.futures as futures
import fdio
import os
from pathlib import Path
import socket
import tempfile
import typing as tp
def _xfer_fds(sent_fds: tp.List[int]) -> tp.List[int]:
    """Pass *sent_fds* through a fresh AF_UNIX socketpair.

    One worker thread sends the descriptors with ``fdio.send_fds`` while
    another receives them with ``fdio.recv_fds``; the received descriptors
    are returned.  The socketpair and the thread pool are released before
    returning.
    """
    timeout = 5
    sender, receiver = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
    with sender, receiver:
        # Bound the blocking recvmsg so a lost message fails the test
        # instead of leaving the worker thread (and pool shutdown) hung.
        receiver.settimeout(timeout)
        with futures.ThreadPoolExecutor(max_workers=2) as pool:
            send_future = pool.submit(fdio.send_fds, sender, sent_fds)
            recv_future = pool.submit(fdio.recv_fds, receiver, len(sent_fds))
            recvd_fds = recv_future.result(timeout=timeout)
            # Re-raise any exception from the send side instead of
            # silently discarding it.
            send_future.result(timeout=timeout)
    assert len(recvd_fds) == len(sent_fds)
    return recvd_fds
def test_xfer_fds() -> None:
    """Test transferring file descriptors.

    Sends the descriptors of two temporary files through ``_xfer_fds`` and
    checks that each received descriptor refers to the same file as the
    corresponding sent one.
    """
    ntf = tempfile.NamedTemporaryFile
    with ntf() as outf1, ntf() as outf2:
        sent_fds = [outf1.fileno(), outf2.fileno()]
        recvd_fds = _xfer_fds(sent_fds)
        try:
            # Verify that the received descriptors are associated
            # with the same filesystem entries as the sent descriptors.
            # samestat compares st_dev as well as st_ino; an inode number
            # is only unique within one device.
            for sfd, rfd in zip(sent_fds, recvd_fds):
                assert os.path.samestat(os.fstat(sfd), os.fstat(rfd))
        finally:
            # The received descriptors are separate entries in this
            # process's descriptor table; close them so the test does not
            # leak fds.  The sent ones are closed by the `with` block.
            for rfd in recvd_fds:
                os.close(rfd)
def test_xfer_sockets() -> None:
    """Test transferring sockets.

    Sends both ends of a fresh socketpair through ``_xfer_fds``, rebuilds
    socket objects from the received descriptors, and checks that data
    written on each original end arrives on the received copy of the
    opposite end.  Every socket and descriptor is closed afterwards.
    """
    timeout = 5
    sent_socks = list(socket.socketpair())
    recvd_socks: tp.List[socket.socket] = []
    try:
        sent_fds = [s.fileno() for s in sent_socks]
        recvd_fds = _xfer_fds(sent_fds)
        try:
            recvd_socks = [
                socket.fromfd(rfd, socket.AF_UNIX, socket.SOCK_STREAM)
                for rfd in recvd_fds]
        finally:
            # socket.fromfd duplicates its descriptor, so the raw fds
            # returned by _xfer_fds are no longer needed.
            for rfd in recvd_fds:
                os.close(rfd)

        # Give every socket object the same timeout: duplicated descriptors
        # share blocking mode at the OS level, and a timeout turns a lost
        # message into a test failure instead of a hang.
        for sock in sent_socks + recvd_socks:
            sock.settimeout(timeout)

        # Verify that the sockets are usable.
        should_be_connected = [
            (sent_socks[0], recvd_socks[1]),
            (sent_socks[1], recvd_socks[0]),
        ]
        with futures.ThreadPoolExecutor(max_workers=2) as pool:
            for i, (ssend, srecv) in enumerate(should_be_connected):
                msg: bytes = f"Send / Receive Test {i}".encode("utf8")
                # Pass sockets and message as submit() arguments instead
                # of closing over the loop variables (late binding).
                send_future = pool.submit(ssend.sendall, msg)
                recvd_msg = pool.submit(srecv.recv, len(msg)).result()
                send_future.result()
                assert recvd_msg == msg
    finally:
        for sock in sent_socks + recvd_socks:
            sock.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment