Skip to content

Instantly share code, notes, and snippets.

@umurgdk
Last active November 7, 2024 10:15
Show Gist options
  • Save umurgdk/f054d5724478b18e2afbf13c2b08cd0e to your computer and use it in GitHub Desktop.
Save umurgdk/f054d5724478b18e2afbf13c2b08cd0e to your computer and use it in GitHub Desktop.
Sharing file descriptors over Unix domain sockets in Zig
const std = @import("std");
const posix = std.posix;
const c = std.c;
/// Maximum number of descriptors accepted in one received message; used to
/// size the receive control buffer. (Presumably mirrors the Linux kernel's
/// `SCM_MAX_FD` limit — verify for other platforms.)
pub const SCM_MAX_FD = 253;

// Control-message type codes, taken from
// https://git.musl-libc.org/cgit/musl/tree/include/sys/socket.h

/// Control-message type used to pass file descriptors.
pub const SCM_RIGHTS = 1;
/// Control-message type used to pass process credentials.
pub const SCM_CREDENTIALS = 2;
/// Header of one ancillary (control) message, overlaid on the raw bytes of
/// the control buffer exchanged with the kernel by `sendmsg`/`recvmsg`.
///
/// Declared `extern` so field order and padding follow the C ABI. A plain
/// Zig `struct` has an unspecified layout and is not guaranteed to match the
/// bytes the kernel reads and writes. `cmsg_len` is word-sized here.
/// NOTE(review): this matches Linux's `struct cmsghdr`; the BSDs/macOS use a
/// 32-bit length, so confirm before porting.
const cmsghdr = extern struct {
    cmsg_len: usize,
    cmsg_level: c_int,
    cmsg_type: c_int,
};
/// Receive one message on the Unix domain socket `sock`, together with any
/// file descriptors passed alongside it in `SCM_RIGHTS` control messages.
///
/// - `msg_buff`: receives the regular payload bytes.
/// - `msg_len`: set to the number of payload bytes actually received, which is
///   0 on an orderly shutdown of a stream socket.
/// - `fd_buff`: receives the passed descriptors. The caller owns every
///   descriptor stored here and must close it.
/// - `fd_len`: set to the number of descriptors stored in `fd_buff`.
///
/// Descriptors that do not fit in `fd_buff` are closed, so they do not leak.
/// Returns `error.RecvFailed` if `recvmsg` fails. Returns
/// `error.InvalidFdControlMessage` if any control message other than
/// `SOL_SOCKET`/`SCM_RIGHTS` arrives. In that case every descriptor received
/// by this call is closed and `fd_len` is reset to 0.
pub fn recvFd(sock: posix.fd_t, msg_buff: []u8, msg_len: *usize, fd_buff: []posix.fd_t, fd_len: *usize) !void {
    const cmsg_size = comptime CMSG_SPACE(@sizeOf(posix.fd_t) * SCM_MAX_FD);
    // Reinterpreted as `cmsghdr` below, so it must be aligned like one.
    var ctrl_buf align(@alignOf(cmsghdr)) = std.mem.zeroes([cmsg_size]u8);

    var iov = posix.iovec{ .base = msg_buff.ptr, .len = msg_buff.len };
    var msg = std.mem.zeroes(posix.msghdr);
    msg.name = null;
    msg.namelen = 0;
    msg.iov = @ptrCast(&iov);
    msg.iovlen = 1;
    msg.control = &ctrl_buf;
    msg.controllen = ctrl_buf.len;

    const rc = c.recvmsg(sock, &msg, 0);
    if (rc < 0) {
        return error.RecvFailed;
    }
    // recvmsg's return value is the number of payload bytes received. The
    // iovec length is only the buffer capacity and is left untouched.
    msg_len.* = @intCast(rc);
    fd_len.* = 0;

    // recvmsg rewrites `controllen` to the amount of control data actually
    // received. That is 0 when the sender attached no descriptors, so there is
    // no header to read.
    var next_cmsg: ?*cmsghdr = if (msg.controllen >= @sizeOf(cmsghdr)) CMSG_FIRSTHDR(&msg) else null;
    var saw_invalid = false;
    while (next_cmsg) |cmsg| : (next_cmsg = CMSG_NXTHDR(&msg, cmsg)) {
        if (cmsg.cmsg_level != c.SOL.SOCKET or cmsg.cmsg_type != SCM_RIGHTS) {
            // Keep walking so descriptors in later messages are still
            // accounted for (and closed below) instead of leaking.
            saw_invalid = true;
            continue;
        }
        // A single SCM_RIGHTS message may carry several descriptors.
        // Saturating subtraction guards against a malformed short cmsg_len.
        const payload_len = cmsg.cmsg_len -| CMSG_LEN(0);
        const fd_count = payload_len / @sizeOf(posix.fd_t);
        const fds: [*]const posix.fd_t = @ptrCast(CMSG_DATA(posix.fd_t, cmsg));
        for (fds[0..fd_count]) |fd| {
            if (fd_len.* < fd_buff.len) {
                fd_buff[fd_len.*] = fd;
                fd_len.* += 1;
            } else {
                // No room for it: the kernel already installed it in this
                // process, so close it rather than leak it.
                posix.close(fd);
            }
        }
    }

    if (saw_invalid) {
        for (fd_buff[0..fd_len.*]) |fd| posix.close(fd);
        fd_len.* = 0;
        return error.InvalidFdControlMessage;
    }
}
/// Send `data` over the Unix domain socket `sock`, passing the descriptor `fd`
/// alongside it in a single `SCM_RIGHTS` control message.
///
/// `data` must be non-empty because the descriptor travels with the first
/// chunk of payload. The caller keeps ownership of `fd`. If the kernel
/// accepts only part of `data` (possible on stream sockets), the remainder is
/// sent with plain `send` calls, so the whole payload is always written.
pub fn sendFd(sock: posix.fd_t, fd: posix.fd_t, data: []const u8) !void {
    std.debug.assert(data.len > 0);

    // The buffer is reinterpreted as a `cmsghdr` (via @alignCast in
    // CMSG_FIRSTHDR), so it must carry that type's alignment.
    var ctrl_buf align(@alignOf(cmsghdr)) = std.mem.zeroes([CMSG_SPACE(@sizeOf(posix.fd_t))]u8);
    // `iovec_const` lets us point at the read-only payload without @constCast.
    var iov = posix.iovec_const{ .base = data.ptr, .len = data.len };

    var msg = std.mem.zeroes(posix.msghdr_const);
    msg.name = null;
    msg.namelen = 0;
    msg.iov = @ptrCast(&iov);
    msg.iovlen = 1;
    msg.control = &ctrl_buf;
    msg.controllen = ctrl_buf.len;

    // The buffer was sized for exactly one header, so it is always present.
    const cmsg = CMSG_FIRSTHDR(&msg) orelse unreachable;
    cmsg.cmsg_level = c.SOL.SOCKET;
    cmsg.cmsg_type = SCM_RIGHTS;
    cmsg.cmsg_len = CMSG_LEN(@sizeOf(posix.fd_t));
    CMSG_DATA(posix.fd_t, cmsg).* = fd;

    var sent = try posix.sendmsg(sock, &msg, 0);
    // The descriptor went out with the first chunk; finish any short write.
    while (sent < data.len) {
        sent += try posix.send(sock, data[sent..], 0);
    }
}
/// Round `len` up to the next multiple of the machine word size
/// (`@sizeOf(usize)`), like the C `CMSG_ALIGN` macro.
inline fn CMSG_ALIGN(len: usize) usize {
    const word = @sizeOf(usize);
    return (len + word - 1) & ~@as(usize, word - 1);
}
/// Value to store in `cmsg_len` for a control message whose payload is `len`
/// bytes: the word-aligned header size plus the unpadded payload length
/// (the C `CMSG_LEN` macro).
inline fn CMSG_LEN(len: usize) usize {
    const header_len = CMSG_ALIGN(@sizeOf(cmsghdr));
    return len + header_len;
}
/// Buffer bytes needed for one control message carrying a `len`-byte payload,
/// including the padding after the payload (the C `CMSG_SPACE` macro).
inline fn CMSG_SPACE(len: usize) usize {
    const header_len = CMSG_ALIGN(@sizeOf(cmsghdr));
    const padded_payload_len = CMSG_ALIGN(len);
    return header_len + padded_payload_len;
}
/// Pointer to the payload that immediately follows the header of `cmsg`,
/// reinterpreted as `*T` (the C `CMSG_DATA` macro).
///
/// `@alignCast` checks in safe builds that the payload offset is suitably
/// aligned for `T`.
inline fn CMSG_DATA(comptime T: type, cmsg: *cmsghdr) *T {
    // The pointer itself is never reassigned, so it is `const`.
    const bytes: [*]u8 = @ptrCast(cmsg);
    return @ptrCast(@alignCast(bytes + @sizeOf(cmsghdr)));
}
/// First control-message header in the control buffer of `hdr` (a `msghdr`
/// or `msghdr_const`).
///
/// Returns null when the buffer is absent or too short to hold a `cmsghdr`,
/// matching C's `CMSG_FIRSTHDR`. The original asserted here instead, which
/// crashed (or was UB in release builds) for a message received with no
/// ancillary data, where `controllen` is 0.
///
/// The result is a mutable pointer even for `msghdr_const`, because senders
/// fill in the header of a buffer they own before handing it to `sendmsg`.
inline fn CMSG_FIRSTHDR(hdr: anytype) ?*cmsghdr {
    if (hdr.controllen < @sizeOf(cmsghdr)) return null;
    const control = hdr.control orelse return null;
    const raw_ptr: *u8 = @constCast(@ptrCast(control));
    return @ptrCast(@alignCast(raw_ptr));
}
/// Header following `cmsg` in the control buffer of `mhdr`, or null when
/// `cmsg` is the last one (a port of musl's `CMSG_NXTHDR`).
inline fn CMSG_NXTHDR(mhdr: *posix.msghdr, cmsg: *cmsghdr) ?*cmsghdr {
    const end_addr = @intFromPtr(__MHDR_END(mhdr));
    const cur_addr = @intFromPtr(cmsg);
    // A length shorter than the header itself is malformed: stop iterating.
    if (cmsg.cmsg_len < @sizeOf(cmsghdr)) return null;
    // The next header must fit entirely before the end of the control data.
    const bytes_left = end_addr - cur_addr;
    if (__CMSG_LEN(cmsg) + @sizeOf(cmsghdr) >= bytes_left) return null;
    return @ptrCast(@alignCast(__CMSG_NEXT(cmsg)));
}
/// `cmsg.cmsg_len` rounded up to a multiple of `@sizeOf(c_long)`: the byte
/// stride from this control message to the next (musl's `__CMSG_LEN`).
inline fn __CMSG_LEN(cmsg: *cmsghdr) usize {
    const raw_len: usize = @intCast(cmsg.cmsg_len);
    return std.mem.alignForward(usize, raw_len, @sizeOf(c_long));
}
/// Address of the byte just past the padded extent of `cmsg`, where the next
/// control message would begin (musl's `__CMSG_NEXT`).
inline fn __CMSG_NEXT(cmsg: *cmsghdr) *u8 {
    const bytes: [*]u8 = @ptrCast(cmsg);
    const stride = __CMSG_LEN(cmsg);
    return @ptrCast(bytes + stride);
}
/// One-past-the-end address of the control buffer of `mhdr`, i.e.
/// `control + controllen` (musl's `__MHDR_END`). `control` must be non-null.
inline fn __MHDR_END(mhdr: *posix.msghdr) *u8 {
    const control = mhdr.control orelse unreachable;
    const bytes: [*]u8 = @ptrCast(control);
    return @ptrCast(bytes + mhdr.controllen);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment