Skip to content

Instantly share code, notes, and snippets.

@lithdew
Created October 13, 2020 08:50
Show Gist options
  • Select an option

  • Save lithdew/0fe0eb50d9a56582be65cf9e2d03c91f to your computer and use it in GitHub Desktop.

Select an option

Save lithdew/0fe0eb50d9a56582be65cf9e2d03c91f to your computer and use it in GitHub Desktop.
zig (windows): afd poll w/ WaitForSingleObjectEx
const std = @import("std");
const os = std.os;
const windows = os.windows;
const assert = std.debug.assert;
/// One per-socket entry of an IOCTL_AFD_POLL buffer; entries follow an
/// AFD_POLL_INFO header. Declared `extern` so the field order and layout are the
/// C ABI layout handed to the driver — do not reorder fields.
pub const AFD_POLL_HANDLE_INFO = extern struct {
/// Socket handle to poll (waitFor passes the socket's base handle here).
Handle: windows.HANDLE,
/// Mask of AFD_POLL_* bits: the requested events on input, the reported events on output.
Events: windows.ULONG,
/// Per-handle NTSTATUS (waitFor initialises it to SUCCESS).
Status: windows.NTSTATUS,
};
/// Header of an IOCTL_AFD_POLL input/output buffer. Declared `extern` so the
/// layout matches the C definition; the handle entries are not part of this
/// struct and must be placed directly after it in the same buffer.
pub const AFD_POLL_INFO = extern struct {
/// Poll timeout. waitFor always passes maxInt(i64).
/// NOTE(review): units/sign convention are not established here (presumably NT
/// 100ns ticks, negative = relative) — confirm before passing other values.
Timeout: windows.LARGE_INTEGER,
/// Number of AFD_POLL_HANDLE_INFO entries that follow this header.
NumberOfHandles: windows.ULONG,
/// NOTE(review): presumably non-zero requests an exclusive poll; waitFor always passes 0 — confirm.
Exclusive: windows.ULONG,
// followed by an array of `NumberOfHandles` AFD_POLL_HANDLE_INFO
// Handles[]: AFD_POLL_HANDLE_INFO,
};
// AFD request flag bits.
// NOTE(review): presumably these are for the `AfdFlags` field of AFD_RECV_INFO /
// AFD_RECV_DATAGRAM_INFO; nothing in this file uses them — confirm before relying on them.
pub const AFD_NO_FAST_IO = 0x00000001;
pub const AFD_OVERLAPPED = 0x00000002;
pub const AFD_IMMEDIATE = 0x00000004;
// AFD poll event bits, used in AFD_POLL_HANDLE_INFO.Events.
// Each `*_BIT` constant is a bit position and the matching unsuffixed constant is
// its mask (`1 << bit`). Values are fixed by the driver ABI — do not renumber.
pub const AFD_POLL_RECEIVE_BIT = 0;
pub const AFD_POLL_RECEIVE = 1 << AFD_POLL_RECEIVE_BIT;
pub const AFD_POLL_RECEIVE_EXPEDITED_BIT = 1;
pub const AFD_POLL_RECEIVE_EXPEDITED = 1 << AFD_POLL_RECEIVE_EXPEDITED_BIT;
pub const AFD_POLL_SEND_BIT = 2;
pub const AFD_POLL_SEND = 1 << AFD_POLL_SEND_BIT;
pub const AFD_POLL_DISCONNECT_BIT = 3;
pub const AFD_POLL_DISCONNECT = 1 << AFD_POLL_DISCONNECT_BIT;
pub const AFD_POLL_ABORT_BIT = 4;
pub const AFD_POLL_ABORT = 1 << AFD_POLL_ABORT_BIT;
pub const AFD_POLL_LOCAL_CLOSE_BIT = 5;
pub const AFD_POLL_LOCAL_CLOSE = 1 << AFD_POLL_LOCAL_CLOSE_BIT;
pub const AFD_POLL_CONNECT_BIT = 6;
pub const AFD_POLL_CONNECT = 1 << AFD_POLL_CONNECT_BIT;
pub const AFD_POLL_ACCEPT_BIT = 7;
pub const AFD_POLL_ACCEPT = 1 << AFD_POLL_ACCEPT_BIT;
pub const AFD_POLL_CONNECT_FAIL_BIT = 8;
pub const AFD_POLL_CONNECT_FAIL = 1 << AFD_POLL_CONNECT_FAIL_BIT;
pub const AFD_POLL_QOS_BIT = 9;
pub const AFD_POLL_QOS = 1 << AFD_POLL_QOS_BIT;
pub const AFD_POLL_GROUP_QOS_BIT = 10;
pub const AFD_POLL_GROUP_QOS = 1 << AFD_POLL_GROUP_QOS_BIT;
/// Number of poll event bits defined above (bits 0 through 10).
pub const AFD_NUM_POLL_EVENTS = 11;
/// Mask with every poll event bit set (0x7FF).
pub const AFD_POLL_ALL = (1 << AFD_NUM_POLL_EVENTS) - 1;
/// Request layout for a datagram receive on an AFD handle. NOTE(review): presumably the
/// input buffer for IOCTL_AFD_RECEIVE_DATAGRAM; nothing in this file uses it — confirm.
/// Declared `extern` so the field order is the C ABI layout — do not reorder fields.
pub const AFD_RECV_DATAGRAM_INFO = extern struct {
/// Array of `BufferCount` WSABUF descriptors to receive into.
BufferArray: windows.LPWSABUF,
BufferCount: windows.ULONG,
/// NOTE(review): presumably AFD_* request flags (e.g. AFD_OVERLAPPED) — confirm.
AfdFlags: windows.ULONG,
/// NOTE(review): presumably TDI receive flags — confirm.
TdiFlags: windows.ULONG,
/// Receives the sender's address.
Address: *windows.sockaddr,
/// NOTE(review): ReactOS declares this field as PINT (a 32-bit int pointer). A `*usize`
/// has an 8-byte pointee on 64-bit targets, so the driver may only write the low 4 bytes
/// — confirm the pointee width before using this struct.
AddressLength: *usize,
};
/// Request layout for a stream receive on an AFD handle. NOTE(review): presumably the
/// input buffer for IOCTL_AFD_RECEIVE; nothing in this file uses it — confirm.
/// Declared `extern` so the field order is the C ABI layout — do not reorder fields.
pub const AFD_RECV_INFO = extern struct {
/// Array of `BufferCount` WSABUF descriptors to receive into.
BufferArray: windows.LPWSABUF,
BufferCount: windows.ULONG,
/// NOTE(review): presumably AFD_* request flags (e.g. AFD_OVERLAPPED) — confirm.
AfdFlags: windows.ULONG,
/// NOTE(review): presumably TDI receive flags — confirm.
TdiFlags: windows.ULONG,
};
/// Compose an AFD device I/O control code.
///
/// Layout: FILE_DEVICE_NETWORK shifted into bits 12 and up, the 10-bit AFD
/// `function` number in bits 2-11, and the buffer transfer `method` in the low
/// bits. Usable at comptime (the IOCTL_AFD_* constants are built with it).
pub fn AFD_CONTROL_CODE(function: u10, method: windows.TransferType) windows.DWORD {
    const device_bits = @as(windows.DWORD, windows.FILE_DEVICE_NETWORK) << 12;
    const function_bits = @as(windows.DWORD, function) << 2;
    const method_bits: windows.DWORD = @enumToInt(method);
    return device_bits | function_bits | method_bits;
}
// AFD function numbers: the `function` argument to AFD_CONTROL_CODE.
pub const AFD_RECEIVE = 5;
pub const AFD_RECEIVE_DATAGRAM = 6;
pub const AFD_POLL = 9;
// Complete device I/O control codes for the functions above. The receive ioctls
// use METHOD_NEITHER; the poll ioctl uses METHOD_BUFFERED (waitFor passes one
// buffer as both its input and its output).
pub const IOCTL_AFD_RECEIVE = AFD_CONTROL_CODE(AFD_RECEIVE, .METHOD_NEITHER);
pub const IOCTL_AFD_RECEIVE_DATAGRAM = AFD_CONTROL_CODE(AFD_RECEIVE_DATAGRAM, .METHOD_NEITHER);
pub const IOCTL_AFD_POLL = AFD_CONTROL_CODE(AFD_POLL, .METHOD_BUFFERED);
/// Open the AFD device `\Device\Afd` with NtCreateFile.
///
/// The handle is opened with SYNCHRONIZE access, read/write sharing, the
/// FILE_OPEN disposition and no create options, so I/O on it is not forced to be
/// synchronous. The caller owns the returned handle and must release it with
/// windows.CloseHandle. Returns error.Unexpected for any non-success NTSTATUS.
fn CreateAfdHandle() !windows.HANDLE {
    const afd_path = "\\Device\\Afd";
    // The device path is pure ASCII, so its UTF-16 code-unit count equals its byte
    // count. Sizing the buffer exactly avoids reserving a PATH_MAX_WIDE (~64 KiB)
    // stack array just to hold an 11-character constant.
    var afdName: [afd_path.len]u16 = undefined;
    const afdNameLength = @sizeOf(u16) * try std.unicode.utf8ToUtf16Le(afdName[0..], afd_path);
    var afdNameUS = windows.UNICODE_STRING{
        // UNICODE_STRING lengths are in bytes, not code units.
        .Length = @intCast(u16, afdNameLength),
        .MaximumLength = @sizeOf(@TypeOf(afdName)),
        .Buffer = &afdName,
    };

    var afd: windows.HANDLE = undefined;
    var objectAttributes: windows.OBJECT_ATTRIBUTES = .{
        .Length = @sizeOf(windows.OBJECT_ATTRIBUTES),
        .RootDirectory = null,
        .Attributes = 0,
        .ObjectName = &afdNameUS,
        .SecurityDescriptor = null,
        .SecurityQualityOfService = null,
    };
    var ioStatusBlock: windows.IO_STATUS_BLOCK = undefined;
    const status = windows.ntdll.NtCreateFile(
        &afd,
        windows.SYNCHRONIZE, // lets callers wait on the handle itself
        &objectAttributes,
        &ioStatusBlock,
        null, // allocation size
        0, // file attributes
        windows.FILE_SHARE_READ | windows.FILE_SHARE_WRITE,
        windows.FILE_OPEN,
        0, // create options: none, so the handle is not synchronous-only
        null, // EA buffer
        0, // EA length
    );
    return switch (status) {
        .SUCCESS => afd,
        else => windows.unexpectedStatus(status),
    };
}
/// Resolve `sock` to its base socket handle using the SIO_BASE_HANDLE ioctl
/// (per the Winsock docs, the base service provider's handle underneath any
/// layered providers). The ioctl takes no input. The function asserts that exactly
/// one SOCKET's worth of bytes was written, and propagates WSAIoctl errors.
fn getBaseSocket(sock: windows.ws2_32.SOCKET) !windows.ws2_32.SOCKET {
    const Socket = windows.ws2_32.SOCKET;
    var base: Socket = undefined;

    const no_input = &[_]u8{};
    // View the output SOCKET as raw bytes so WSAIoctl can fill it in place.
    const base_bytes = @ptrCast([*]u8, &base)[0..@sizeOf(Socket)];

    const written = try windows.WSAIoctl(sock, windows.ws2_32.SIO_BASE_HANDLE, no_input, base_bytes, null, null);
    assert(written == @sizeOf(Socket));
    return base;
}
/// Block until AFD reports one of the requested poll `events` on `sock`, then
/// return the event mask AFD reported for it.
///
/// Opens a private `\Device\Afd` handle and issues a single IOCTL_AFD_POLL for
/// the socket's base handle. If the request pends, it waits non-alertably on that
/// AFD handle until the request completes.
///
/// `events` is a mask of AFD_POLL_* bits. Returns 0 if the request completes
/// without any handle entry in its output. Returns error.Unexpected for a
/// non-success NTSTATUS. Errors from opening the AFD handle, resolving the base
/// socket, or waiting are propagated.
///
/// NOTE(review): `timeout` is currently ignored. The AFD request always uses
/// Timeout = maxInt(i64) and the wait is INFINITE. Its units are not defined
/// anywhere (the demo caller passes windows.INFINITE). Wire it up only together
/// with cancelling the in-flight IOCTL on expiry; otherwise the kernel would
/// still be writing into this stack frame after return.
fn waitFor(sock: windows.ws2_32.SOCKET, events: windows.ULONG, timeout: i64) !windows.ULONG {
    const afd = try CreateAfdHandle();
    defer windows.CloseHandle(afd);

    const base_socket = try getBaseSocket(sock);

    // AFD_POLL_INFO header immediately followed by a one-entry handle array. The
    // same buffer is passed as both the ioctl input and its output.
    var ioctl_data: extern struct {
        AfdPollInfo: AFD_POLL_INFO,
        Handles: [1]AFD_POLL_HANDLE_INFO,
    } = .{
        .AfdPollInfo = .{
            .NumberOfHandles = 1,
            .Timeout = std.math.maxInt(i64),
            .Exclusive = 0,
        },
        .Handles = .{
            .{
                .Handle = @ptrCast(windows.HANDLE, base_socket),
                .Status = .SUCCESS,
                .Events = events,
            },
        },
    };

    // No event object and no APC: completion is observed by waiting on the AFD handle
    // itself. That handle was opened with SYNCHRONIZE and carries only this one
    // request. A *named* event would be shared by every caller and process using the
    // same name, so an unrelated completion could wake us while this request is
    // still writing into `ioctl_data` / `ioStatusBlock`.
    var ioStatusBlock: windows.IO_STATUS_BLOCK = undefined;
    var status = windows.ntdll.NtDeviceIoControlFile(
        afd,
        null, // event: must be null when waiting on the file handle
        null, // apc routine
        null, // apc context
        &ioStatusBlock,
        IOCTL_AFD_POLL,
        &ioctl_data,
        @sizeOf(@TypeOf(ioctl_data)),
        &ioctl_data,
        @sizeOf(@TypeOf(ioctl_data)),
    );
    if (status == .PENDING) {
        // Non-alertable: an alertable wait can return WAIT_IO_COMPLETION (surfaced as
        // an error) while the request is still in flight against this stack frame.
        try windows.WaitForSingleObjectEx(afd, windows.INFINITE, false);
        status = ioStatusBlock.u.Status;
    }
    switch (status) {
        .SUCCESS => {},
        else => return windows.unexpectedStatus(status),
    }

    assert(ioStatusBlock.Information >= @sizeOf(AFD_POLL_INFO));
    assert(ioStatusBlock.Information <= @sizeOf(@TypeOf(ioctl_data)));
    const nRetHandles = (ioStatusBlock.Information - @sizeOf(AFD_POLL_INFO)) / @sizeOf(AFD_POLL_HANDLE_INFO);
    // The output may carry no handle entries; report "no events" instead of
    // indexing an empty slice.
    if (nRetHandles == 0) return 0;
    const RetHandles = ioctl_data.Handles[0..nRetHandles];
    std.debug.warn("RetHandles={}\n", .{RetHandles[0]});
    return RetHandles[0].Events;
}
/// Issue kernel32 ReadFile on `fd` into `buf`, passing `overlapped` through.
///
/// The requested length is `buf.len` clamped to the DWORD range. lpNumberOfBytesRead
/// is passed as null, so no byte count is returned here; the caller must obtain
/// it elsewhere (e.g. from `overlapped`).
///
/// Win32 failures are mapped as follows:
///   IO_PENDING -> error.WouldBlock
///   OPERATION_ABORTED -> error.OperationAborted
///   BROKEN_PIPE -> error.BrokenPipe
///   HANDLE_EOF / NETNAME_DELETED -> treated as success
///   anything else -> unexpected error
pub fn ReadFile(fd: os.fd_t, buf: []u8, overlapped: *windows.OVERLAPPED) !void {
    const max_len = std.math.maxInt(windows.DWORD);
    const request_len: windows.DWORD = std.math.cast(windows.DWORD, buf.len) catch max_len;

    if (windows.kernel32.ReadFile(fd, buf.ptr, request_len, null, overlapped) != windows.FALSE) return;

    const last_error = windows.kernel32.GetLastError();
    switch (last_error) {
        .HANDLE_EOF, .NETNAME_DELETED => return,
        .IO_PENDING => return error.WouldBlock,
        .OPERATION_ABORTED => return error.OperationAborted,
        .BROKEN_PIPE => return error.BrokenPipe,
        else => return windows.unexpectedError(last_error),
    }
}
/// Demo driver. Connects a TCP socket to 127.0.0.1:9000, then twice waits for
/// AFD_POLL_RECEIVE on it, reads up to 1024 bytes, and prints what arrived.
pub fn main() !void {
    _ = try windows.WSAStartup(2, 2);
    defer windows.WSACleanup() catch @panic("unable to WSACleanup()");

    const dest = try std.net.Address.parseIp("127.0.0.1", 9000);
    const sock = try windows.WSASocketW(dest.any.family, os.SOCK_STREAM, os.IPPROTO_TCP, null, 0, 0);
    defer windows.closesocket(sock) catch @panic("closesocket failed");

    if (windows.ws2_32.connect(sock, &dest.any, dest.getOsSockLen()) != 0) {
        return windows.unexpectedWSAError(windows.ws2_32.WSAGetLastError());
    }

    var buffer: [1024]u8 = undefined;
    var overlapped = windows.OVERLAPPED{
        .Internal = 0,
        .InternalHigh = 0,
        .Offset = 0,
        .OffsetHigh = 0,
        .hEvent = null,
    };

    // Two identical rounds: wait for readability, then read and echo the data.
    var round: u8 = 0;
    while (round < 2) : (round += 1) {
        const ready = try waitFor(sock, AFD_POLL_RECEIVE, windows.INFINITE);
        std.debug.warn("got events={}\n", .{ready});

        ReadFile(sock, &buffer, &overlapped) catch |err| @panic(@errorName(err));
        // The byte count of the read is taken from the OVERLAPPED's InternalHigh.
        const received = buffer[0..overlapped.InternalHigh];
        std.debug.print("Got text: {}", .{received});
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment