Simple HTTP server compatible with wrk, for Linux, using epoll. The files below are successive revisions of the same server: a single-threaded callback-based epoll loop, one epoll instance per worker thread, a shared epoll fd with EPOLLONESHOT across threads, and two thread-pool based variants, followed by a work-stealing thread pool implementation.
const std = @import("std"); | |
pub fn main() !void { | |
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; | |
const allocator = &gpa.allocator; | |
defer _ = gpa.deinit(); | |
var poller = try Poller.init(allocator); | |
defer poller.deinit(); | |
try Server.start(&poller, 12345); | |
while (true) { | |
poller.poll(); | |
} | |
} | |
const Poller = struct { | |
fd: std.os.fd_t, | |
allocator: *std.mem.Allocator, | |
fn init(allocator: *std.mem.Allocator) !Poller { | |
return Poller{ | |
.fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC), | |
.allocator = allocator, | |
}; | |
} | |
fn deinit(self: *Poller) void { | |
std.os.close(self.fd); | |
} | |
const Event = struct { | |
const Callback = struct { | |
onEventFn: fn(*Callback, Event) void, | |
}; | |
is_closable: bool, | |
is_readable: bool, | |
is_writable: bool, | |
}; | |
const Socket = struct { | |
poller: *Poller, | |
fd: std.os.socket_t, | |
callback: Event.Callback, | |
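// Generic constructor: allocates the containing type (Self), embeds this Socket in it,
// and registers the fd with epoll. Events are dispatched back to Self through
// @fieldParentPtr, recovering the container from the embedded callback pointer.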
fn start(comptime Self: type, poller: *Poller, fd: std.os.socket_t) !*Self { | |
const Callback = struct { | |
fn onEvent(callback: *Event.Callback, event: Event) void { | |
const socket = @fieldParentPtr(Socket, "callback", callback); | |
const self = @fieldParentPtr(Self, "socket", socket); | |
handleEvent(self, event) catch { | |
self.onClose(); | |
std.os.close(self.socket.fd); | |
self.socket.poller.allocator.destroy(self); | |
}; | |
} | |
fn handleEvent(self: *Self, event: Event) !void { | |
if (event.is_closable) | |
return error.Closed; | |
if (event.is_readable) | |
try self.onRead(); | |
if (event.is_writable) | |
try self.onWrite(); | |
} | |
}; | |
const self = try poller.allocator.create(Self); | |
errdefer poller.allocator.destroy(self); | |
self.socket = .{ | |
.poller = poller, | |
.fd = fd, | |
.callback = .{ | |
.onEventFn = Callback.onEvent, | |
}, | |
}; | |
// Register with edge-triggering (EPOLLET) so that readiness re-arms automatically once we drain IO (read/write until WouldBlock).
// This saves the extra epoll_ctl() syscall per wakeup that re-arming manually (e.g. with EPOLLONESHOT) would require.
try std.os.epoll_ctl(poller.fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{ | |
.events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLET | std.os.EPOLLRDHUP, | |
.data = .{ .ptr = @ptrToInt(&self.socket.callback) }, | |
}); | |
return self; | |
} | |
}; | |
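// Blocks in epoll_wait, then invokes each ready socket's callback with decoded
// readability/writability/closed flags.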
fn poll(self: *Poller) void { | |
var events: [128]std.os.epoll_event = undefined; | |
const events_found = std.os.epoll_wait(self.fd, &events, -1); | |
if (events_found == 0) | |
return; | |
for (events[0..events_found]) |ev| { | |
const callback = @intToPtr(*Event.Callback, ev.data.ptr); | |
(callback.onEventFn)(callback, Event{ | |
.is_closable = ev.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0, | |
.is_readable = ev.events & std.os.EPOLLIN != 0, | |
.is_writable = ev.events & std.os.EPOLLOUT != 0, | |
}); | |
} | |
} | |
}; | |
const Server = struct { | |
socket: Poller.Socket, | |
port: u16, | |
const SOCK_FLAGS = std.os.SOCK_CLOEXEC | std.os.SOCK_NONBLOCK; | |
fn start(poller: *Poller, comptime port: u16) !void { | |
const fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | SOCK_FLAGS, std.os.IPPROTO_TCP); | |
errdefer std.os.close(fd); | |
// Bind the socket to the port on the local address | |
const address = "127.0.0.1"; | |
var addr = comptime std.net.Address.parseIp(address, port) catch unreachable; | |
try std.os.setsockopt(fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(fd, 128); | |
const self = try Poller.Socket.start(Server, poller, fd); | |
self.port = port; | |
std.debug.warn("Listening on {}:{}", .{address, port}); | |
} | |
fn onClose(self: *Server) void { | |
std.debug.warn("server shutdown for port: {}", .{self.port}); | |
} | |
fn onWrite(self: *Server) !void { | |
std.debug.panic("server shouldn't writable", .{}); | |
} | |
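// With edge-triggered epoll, the accept backlog must be drained completely;
// keep accepting until the listening socket reports WouldBlock.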
fn onRead(self: *Server) !void { | |
while (true) { | |
const client_fd = std.os.accept(self.socket.fd, null, null, SOCK_FLAGS) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
Client.start(self.socket.poller, client_fd) catch |err| { | |
std.debug.warn("Failed to spawn a client: {}\n", .{err}); | |
continue; | |
}; | |
} | |
} | |
}; | |
const Client = struct { | |
socket: Poller.Socket, | |
send_bytes: usize, | |
send_partial: usize, | |
recv_bytes: usize, | |
recv_buffer: [4096]u8, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn start(poller: *Poller, fd: std.os.socket_t) !void { | |
errdefer std.os.close(fd); | |
// Enable TCP-NoDelay to send the http responses as fast as possible. | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
const self = try Poller.Socket.start(Client, poller, fd); | |
self.send_bytes = 0; | |
self.send_partial = 0; | |
self.recv_bytes = 0; | |
self.recv_buffer = undefined; | |
} | |
fn onClose(self: *Client) void { | |
// Do nothing | |
} | |
fn onRead(self: *Client) !void { | |
while (true) { | |
const request_buffer = self.recv_buffer[0..self.recv_bytes]; | |
// Try to parse and consume a request in the request_buffer | |
// by matching everything until the end of an HTTP request with no body | |
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| { | |
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len]; | |
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer); | |
self.recv_bytes = unparsed_buffer.len; | |
// If found, count that as a parsed request | |
// and have the writer write the static HTTP response (eventually). | |
self.send_bytes += HTTP_RESPONSE.len; | |
continue; | |
} | |
// A complete request hasn't been parsed yet.
// Try to read more data into the buffer and try again. | |
const readable_buffer = self.recv_buffer[self.recv_bytes..]; | |
if (readable_buffer.len == 0) | |
return error.HttpRequestTooLarge; | |
// If the read would normally block, | |
// then we have to wait for the socket to be readable in the future to try again. | |
const bytes_read = std.os.read(self.socket.fd, readable_buffer) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
// 0 bytes read indicates that the socket can no longer read any data. | |
self.recv_bytes += bytes_read; | |
if (bytes_read == 0) | |
return error.EndOfStream; | |
} | |
} | |
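// Responses are never buffered per connection: send_bytes counts how many response
// bytes are owed, writes are served from a comptime-repeated copy of the static
// response, and send_partial tracks the offset into the response currently being sent.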
fn onWrite(self: *Client) !void { | |
const NUM_RESPONSE_CHUNKS = 128; | |
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS; | |
while (self.send_bytes > 0) { | |
// Compute the slice of the response chunk to send, based on send_bytes (bytes owed) and send_partial (offset into the current response).
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
const iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial);
const writable_buffer = iov_base[0..iov_len]; | |
// Perform the actual write. | |
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing. | |
const bytes_written = std.os.sendto( | |
self.socket.fd, | |
writable_buffer, | |
std.os.MSG_NOSIGNAL, | |
null, | |
@as(std.os.socklen_t, 0), | |
) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
self.send_bytes -= bytes_written;
// Advance the offset into the current response by the bytes just written.
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
} | |
} | |
};

---- next file ----
const std = @import("std"); | |
pub fn main() !void { | |
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; | |
defer _ = gpa.deinit(); | |
const allocator = &gpa.allocator; | |
const num_threads = 6; // std.math.max(1, std.Thread.cpuCount() catch 1); | |
const worker_fds = try allocator.alloc(std.os.fd_t, num_threads); | |
defer allocator.free(worker_fds); | |
for (worker_fds) |*worker_fd| | |
worker_fd.* = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC); | |
for (worker_fds[1..]) |worker_fd| | |
_ = try std.Thread.spawn(worker_fd, runWorker); | |
const server_fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
errdefer std.os.close(server_fd); | |
const port = 12345; | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(server_fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(server_fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(server_fd, 128); | |
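// One epoll instance per thread: the main thread polls worker_fds[0] (and runs the
// accept loop), while accepted clients are handed out round-robin via next_fd.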
var next_fd: usize = 0; | |
const epoll_fd = worker_fds[next_fd]; | |
var events: [256]std.os.epoll_event = undefined; | |
var server_event = std.os.epoll_event{ | |
.events = std.os.EPOLLIN | std.os.EPOLLET | std.os.EPOLLRDHUP, | |
.data = .{ .ptr = 0 }, | |
}; | |
try std.os.epoll_ctl( | |
epoll_fd, | |
std.os.EPOLL_CTL_ADD, | |
server_fd, | |
&server_event, | |
); | |
std.debug.warn("Listening on :{}", .{port}); | |
while (true) { | |
const found = std.os.epoll_wait(epoll_fd, &events, -1); | |
for (events[0..found]) |event| { | |
if (event.data.ptr != 0) { | |
Client.process(event); | |
continue; | |
} | |
if (event.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0) | |
unreachable; | |
if (event.events & std.os.EPOLLIN == 0) | |
unreachable; | |
while (true) { | |
const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) { | |
error.WouldBlock => break, | |
else => |e| return e, | |
}; | |
if (Client.start(allocator, worker_fds[next_fd], client_fd)) |_| { | |
next_fd += 1; | |
if (next_fd >= worker_fds.len) | |
next_fd = 0; | |
} else |err| {
std.os.close(client_fd);
std.debug.warn("Failed to start client (fd {}): {}\n", .{client_fd, err});
} | |
} | |
} | |
} | |
} | |
fn runWorker(epoll_fd: std.os.fd_t) void { | |
var events: [256]std.os.epoll_event = undefined; | |
while (true) { | |
const found = std.os.epoll_wait(epoll_fd, &events, -1); | |
for (events[0..found]) |event| | |
Client.process(event); | |
} | |
} | |
const Client = struct { | |
fd: std.os.socket_t, | |
allocator: *std.mem.Allocator, | |
send_bytes: usize = 0, | |
send_partial: usize = 0, | |
recv_bytes: usize = 0, | |
recv_buffer: [4096]u8 = undefined, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn start(allocator: *std.mem.Allocator, epoll_fd: std.os.fd_t, fd: std.os.socket_t) !void { | |
const self = try allocator.create(Client); | |
errdefer allocator.destroy(self); | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
self.* = .{ | |
.fd = fd, | |
.allocator = allocator, | |
}; | |
try std.os.epoll_ctl(epoll_fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{ | |
.events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLET | std.os.EPOLLRDHUP, | |
.data = .{ .ptr = @ptrToInt(self) }, | |
}); | |
} | |
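// Entry point called by whichever thread polls this client's epoll instance;
// any error closes the connection and frees the Client.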
fn process(event: std.os.epoll_event) void { | |
const self = @intToPtr(*Client, event.data.ptr); | |
self.processEvent(event.events) catch { | |
std.os.close(self.fd); | |
self.allocator.destroy(self); | |
}; | |
} | |
fn processEvent(self: *Client, events: u32) !void { | |
if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0) | |
return error.Closed; | |
if (events & std.os.EPOLLIN != 0) | |
try self.processRead(); | |
if (events & std.os.EPOLLOUT != 0) | |
try self.processWrite(); | |
} | |
fn processRead(self: *Client) !void { | |
while (true) { | |
const request_buffer = self.recv_buffer[0..self.recv_bytes]; | |
// Try to parse and consume a request in the request_buffer | |
// by matching everything until the end of an HTTP request with no body | |
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| { | |
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len]; | |
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer); | |
self.recv_bytes = unparsed_buffer.len; | |
// If found, count that as a parsed request | |
// and have the writer write the static HTTP response (eventually). | |
self.send_bytes += HTTP_RESPONSE.len; | |
continue; | |
} | |
// A complete request hasn't been parsed yet.
// Try to read more data into the buffer and try again. | |
const readable_buffer = self.recv_buffer[self.recv_bytes..]; | |
if (readable_buffer.len == 0) | |
return error.HttpRequestTooLarge; | |
// If the read would normally block, | |
// then we have to wait for the socket to be readable in the future to try again. | |
const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
// 0 bytes read indicates that the socket can no longer read any data. | |
self.recv_bytes += bytes_read; | |
if (bytes_read == 0) | |
return error.EndOfStream; | |
} | |
} | |
fn processWrite(self: *Client) !void { | |
const NUM_RESPONSE_CHUNKS = 128; | |
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS; | |
while (self.send_bytes > 0) { | |
// Compute the slice of the response chunk to send, based on send_bytes (bytes owed) and send_partial (offset into the current response).
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial; | |
const iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial); | |
const writable_buffer = iov_base[0..iov_len]; | |
// Perform the actual write. | |
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing. | |
const bytes_written = std.os.sendto( | |
self.fd, | |
writable_buffer, | |
std.os.MSG_NOSIGNAL, | |
null, | |
@as(std.os.socklen_t, 0), | |
) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
self.send_bytes -= bytes_written;
// Advance the offset into the current response by the bytes just written.
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
} | |
} | |
};

---- next file ----
const std = @import("std"); | |
var poll_fd: std.os.fd_t = undefined; | |
var server_fd: std.os.socket_t = undefined; | |
var allocator: *std.mem.Allocator = undefined; | |
pub fn main() !void { | |
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; | |
allocator = &gpa.allocator; | |
defer _ = gpa.deinit(); | |
poll_fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC); | |
defer std.os.close(poll_fd); | |
server_fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
defer std.os.close(server_fd); | |
const port = 12345; | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(server_fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(server_fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(server_fd, 128); | |
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, server_fd, &std.os.epoll_event{ | |
.events = std.os.EPOLLIN | std.os.EPOLLET | std.os.EPOLLRDHUP, | |
.data = .{ .ptr = @ptrToInt(&server_fd) }, | |
}); | |
var threads = std.math.max(1, try std.Thread.cpuCount()); | |
while (threads > 1) : (threads -= 1) | |
_ = try std.Thread.spawn({}, runWorker); | |
std.debug.warn("Listening on :{}\n", .{port}); | |
runWorker({}); | |
} | |
fn runWorker(_: void) void { | |
var events: [256]std.os.epoll_event = undefined; | |
while (true) { | |
const found = std.os.epoll_wait(poll_fd, &events, -1); | |
if (found == 0) | |
continue; | |
for (events[0..found]) |event| { | |
const ptr = event.data.ptr; | |
const flags = event.events; | |
if (ptr == @ptrToInt(&server_fd)) { | |
Client.accept(flags) catch |e| std.debug.warn("failed to accept a client: {}\n", .{e}); | |
continue; | |
} | |
const client = @intToPtr(*Client, ptr); | |
client.process(flags) catch {}; | |
} | |
} | |
} | |
const Client = struct { | |
fd: std.os.socket_t, | |
send_bytes: usize = 0, | |
send_partial: usize = 0, | |
recv_bytes: usize = 0, | |
recv_buffer: [4096]u8 = undefined, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn accept(flags: u32) !void { | |
while (true) { | |
const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
errdefer std.os.close(client_fd); | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(client_fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
const self = try allocator.create(Client); | |
self.* = Client{ .fd = client_fd }; | |
errdefer allocator.destroy(self); | |
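// EPOLLONESHOT: the shared epoll fd is polled by multiple threads, so the client is
// disarmed after each notification; process() re-arms it with EPOLL_CTL_MOD, ensuring
// only one thread handles a given connection at a time.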
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, client_fd, &std.os.epoll_event{ | |
.events = std.os.EPOLLIN | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP, | |
.data = .{ .ptr = @ptrToInt(self) }, | |
}); | |
} | |
} | |
fn process(self: *Client, flags: u32) !void { | |
errdefer { | |
std.os.close(self.fd); | |
allocator.destroy(self); | |
} | |
if (flags & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0) | |
return error.Closed; | |
var written = false; | |
if ((flags & std.os.EPOLLOUT != 0) and (self.send_bytes > 0)) { | |
written = true; | |
try self.processWrite(); | |
} | |
if (flags & std.os.EPOLLIN != 0) | |
try self.processRead(); | |
if (!written and (flags & std.os.EPOLLOUT != 0) and (self.send_bytes > 0)) | |
try self.processWrite(); | |
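// Re-arm the oneshot registration; request writability only while response bytes remain.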
var events: u32 = std.os.EPOLLIN | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP; | |
if (self.send_bytes > 0) | |
events |= std.os.EPOLLOUT; | |
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_MOD, self.fd, &std.os.epoll_event{ | |
.events = events, | |
.data = .{ .ptr = @ptrToInt(self) }, | |
}); | |
} | |
fn processRead(self: *Client) !void { | |
while (true) { | |
const request_buffer = self.recv_buffer[0..self.recv_bytes]; | |
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| { | |
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len]; | |
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer); | |
self.recv_bytes = unparsed_buffer.len; | |
self.send_bytes += HTTP_RESPONSE.len; | |
continue; | |
} | |
const readable_buffer = self.recv_buffer[self.recv_bytes..]; | |
if (readable_buffer.len == 0) | |
return error.HttpRequestTooLarge; | |
const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
self.recv_bytes += bytes_read; | |
if (bytes_read == 0) | |
return error.EndOfStream; | |
} | |
} | |
fn processWrite(self: *Client) !void { | |
const NUM_RESPONSE_CHUNKS = 128; | |
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS; | |
while (self.send_bytes > 0) { | |
// Compute the slice of the response chunk to send, based on send_bytes (bytes owed) and send_partial (offset into the current response).
const send_bytes = self.send_bytes; | |
if (self.send_partial > RESPONSE_CHUNK.len) | |
std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len}); | |
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial; | |
const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial); | |
const writable_buffer = iov_base[0..iov_len]; | |
// Perform the actual write. | |
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing. | |
const bytes_written = std.os.sendto( | |
self.fd, | |
writable_buffer, | |
std.os.MSG_NOSIGNAL, | |
null, | |
@as(std.os.socklen_t, 0), | |
) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
// Advance the offset into the current response by the bytes just written.
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
self.send_bytes -= bytes_written; | |
} | |
} | |
};

---- next file ----
const std = @import("std"); | |
const ThreadPool = @import("./ThreadPoolIO.zig"); | |
pub fn main() !void { | |
var pool = try ThreadPool.init(.{ .max_threads = 6 }); | |
defer pool.deinit(); | |
var server: Server = undefined; | |
try server.init(&pool, 12345); | |
var event = std.StaticResetEvent{}; | |
event.wait(); | |
} | |
const Poller = struct { | |
fd: std.os.fd_t, | |
gpa: std.heap.GeneralPurposeAllocator(.{}) = .{}, | |
fn init(self: *Poller) !void { | |
const fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC); | |
self.* = .{ .fd = fd }; | |
} | |
fn deinit(self: *Poller) void { | |
std.os.close(self.fd); | |
_ = self.gpa.deinit(); | |
} | |
fn getAllocator(self: *Poller) *std.mem.Allocator { | |
return &self.gpa.allocator; | |
} | |
}; | |
const Server = struct { | |
fd: std.os.socket_t, | |
pool: *ThreadPool, | |
io_runnable: ThreadPool.IoRunnable, | |
gpa: std.heap.GeneralPurposeAllocator(.{}), | |
port: u16, | |
start_runnable: ThreadPool.Runnable, | |
fn init(self: *Server, pool: *ThreadPool, comptime port: u16) !void { | |
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
errdefer std.os.close(self.fd); | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(self.fd, 128); | |
self.gpa = .{}; | |
self.pool = pool; | |
self.io_runnable = ThreadPool.IoRunnable{ | |
.is_readable = true, | |
.runnable = .{ .runFn = Server.run }, | |
}; | |
try pool.waitFor(self.fd, &self.io_runnable); | |
self.port = port; | |
self.start_runnable = .{ .runFn = Server.start }; | |
pool.schedule(.{}, &self.start_runnable); | |
} | |
fn start(runnable: *ThreadPool.Runnable) void { | |
const self = @fieldParentPtr(Server, "start_runnable", runnable); | |
std.debug.warn("Listening on :{}\n", .{self.port}); | |
} | |
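// Called by the thread pool when the listening socket has pending I/O; accept() drains
// the backlog and re-registers the IoRunnable with waitFor() for the next readiness.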
fn run(runnable: *ThreadPool.Runnable) void { | |
const io_runnable = @fieldParentPtr(ThreadPool.IoRunnable, "runnable", runnable); | |
const self = @fieldParentPtr(Server, "io_runnable", io_runnable); | |
self.accept() catch |err| { | |
std.os.close(self.fd); | |
std.debug.warn("Server shutdown\n", .{}); | |
}; | |
} | |
fn accept(self: *Server) !void { | |
if (self.io_runnable.is_closable) | |
return error.ServerShutdown; | |
if (!self.io_runnable.is_readable) | |
unreachable; | |
while (true) { | |
const client_fd = std.os.accept(self.fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) { | |
error.WouldBlock => break, | |
else => |e| return e, | |
}; | |
Client.init(client_fd, self) catch |err| { | |
std.os.close(client_fd); | |
std.debug.warn("Failed to spawn client: {}\n", .{err}); | |
}; | |
} | |
try self.pool.waitFor(self.fd, &self.io_runnable); | |
} | |
}; | |
const Client = struct { | |
fd: std.os.socket_t, | |
server: *Server, | |
send_bytes: usize = 0, | |
send_partial: usize = 0, | |
recv_bytes: usize = 0, | |
recv_buffer: [4096]u8 = undefined, | |
io_runnable: ThreadPool.IoRunnable, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn init(fd: std.os.socket_t, server: *Server) !void { | |
const allocator = &server.gpa.allocator; | |
const self = try allocator.create(Client); | |
errdefer allocator.destroy(self); | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
self.* = .{ | |
.fd = fd, | |
.server = server, | |
.io_runnable = .{ | |
.is_readable = true, | |
.runnable = .{ .runFn = Client.run }, | |
}, | |
}; | |
try self.server.pool.waitFor(self.fd, &self.io_runnable); | |
} | |
fn run(runnable: *ThreadPool.Runnable) void { | |
const io_runnable = @fieldParentPtr(ThreadPool.IoRunnable, "runnable", runnable); | |
const self = @fieldParentPtr(Client, "io_runnable", io_runnable); | |
self.process() catch { | |
std.os.close(self.fd); | |
self.server.gpa.allocator.destroy(self); | |
}; | |
} | |
fn process(self: *Client) !void { | |
if (self.io_runnable.is_closable) | |
return error.Closed; | |
var written = false; | |
if (self.io_runnable.is_writable and (self.send_bytes > 0)) { | |
written = true; | |
try self.processWrite(); | |
} | |
if (self.io_runnable.is_readable) | |
try self.processRead(); | |
if (!written and self.io_runnable.is_writable and (self.send_bytes > 0)) | |
try self.processWrite(); | |
self.io_runnable.is_readable = true; | |
self.io_runnable.is_writable = self.send_bytes > 0; | |
try self.server.pool.waitFor(self.fd, &self.io_runnable); | |
} | |
fn processRead(self: *Client) !void { | |
while (true) { | |
const request_buffer = self.recv_buffer[0..self.recv_bytes]; | |
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| { | |
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len]; | |
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer); | |
self.recv_bytes = unparsed_buffer.len; | |
self.send_bytes += HTTP_RESPONSE.len; | |
continue; | |
} | |
const readable_buffer = self.recv_buffer[self.recv_bytes..]; | |
if (readable_buffer.len == 0) | |
return error.HttpRequestTooLarge; | |
const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
self.recv_bytes += bytes_read; | |
if (bytes_read == 0) | |
return error.EndOfStream; | |
} | |
} | |
fn processWrite(self: *Client) !void { | |
const NUM_RESPONSE_CHUNKS = 128; | |
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS; | |
while (self.send_bytes > 0) { | |
// Compute the slice of the response chunk to send, based on send_bytes (bytes owed) and send_partial (offset into the current response).
const send_bytes = self.send_bytes; | |
if (self.send_partial > RESPONSE_CHUNK.len) | |
std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len}); | |
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial; | |
const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial); | |
const writable_buffer = iov_base[0..iov_len]; | |
// Perform the actual write. | |
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing. | |
const bytes_written = std.os.sendto( | |
self.fd, | |
writable_buffer, | |
std.os.MSG_NOSIGNAL, | |
null, | |
@as(std.os.socklen_t, 0), | |
) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
// Advance the offset into the current response by the bytes just written.
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
self.send_bytes -= bytes_written; | |
} | |
} | |
};

---- next file ----
const std = @import("std"); | |
const ThreadPool = @import("./ThreadPool.zig"); | |
pub fn main() !void { | |
var poller: Poller = undefined; | |
try poller.init(); | |
defer poller.deinit(); | |
var server: Server = undefined; | |
try server.init(&poller, 12345); | |
var pool = ThreadPool.init(.{}); | |
defer pool.deinit(); | |
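// The main thread owns epoll: it waits for readiness, collects the ready sockets into
// a Batch, and hands the batch to the thread pool; each socket's runnable re-registers
// the fd (EPOLLONESHOT) once it has been processed.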
while (true) { | |
var events: [1024]std.os.epoll_event = undefined; | |
const found = std.os.epoll_wait(poller.fd, &events, -1); | |
var batch = ThreadPool.Batch{}; | |
defer if (!batch.isEmpty()) | |
pool.schedule(.{}, batch); | |
for (events[0..found]) |event| { | |
const socket = @intToPtr(*Socket, event.data.ptr); | |
socket.events = event.events; | |
batch.push(&socket.runnable); | |
} | |
} | |
} | |
const Poller = struct { | |
fd: std.os.fd_t, | |
gpa: std.heap.GeneralPurposeAllocator(.{}) = .{}, | |
fn init(self: *Poller) !void { | |
const fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC); | |
self.* = .{ .fd = fd }; | |
} | |
fn deinit(self: *Poller) void { | |
std.os.close(self.fd); | |
_ = self.gpa.deinit(); | |
} | |
fn getAllocator(self: *Poller) *std.mem.Allocator { | |
return &self.gpa.allocator; | |
} | |
}; | |
const Socket = struct { | |
fd: std.os.socket_t, | |
poller: *Poller, | |
events: u32, | |
runnable: ThreadPool.Runnable, | |
fn init(self: *Socket, fd: std.os.socket_t, poller: *Poller, comptime Container: type) !void { | |
const Callback = struct { | |
fn runFn(runnable: *ThreadPool.Runnable) void { | |
const socket = @fieldParentPtr(Socket, "runnable", runnable); | |
const container = @fieldParentPtr(Container, "socket", socket); | |
container.run() catch { | |
std.os.close(socket.fd); | |
container.deinit(); | |
}; | |
} | |
}; | |
self.* = .{ | |
.fd = fd, | |
.poller = poller, | |
.events = undefined, | |
.runnable = .{ .runFn = Callback.runFn }, | |
}; | |
try std.os.epoll_ctl(poller.fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{ | |
.events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP, | |
.data = .{ .ptr = @ptrToInt(self) }, | |
}); | |
} | |
fn register(self: *Socket, events: u32) !void { | |
try std.os.epoll_ctl(self.poller.fd, std.os.EPOLL_CTL_MOD, self.fd, &std.os.epoll_event{ | |
.events = events | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP, | |
.data = .{ .ptr = @ptrToInt(self) }, | |
}); | |
} | |
}; | |
const Server = struct { | |
socket: Socket, | |
fn init(self: *Server, poller: *Poller, comptime port: u16) !void { | |
const fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
errdefer std.os.close(fd); | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(fd, 128); | |
try self.socket.init(fd, poller, Server); | |
std.debug.warn("Listening on :{}\n", .{port}); | |
} | |
fn deinit(self: *Server) void { | |
std.debug.warn("Server shutdown\n", .{}); | |
} | |
fn run(self: *Server) !void { | |
const events = self.socket.events; | |
const server_fd = self.socket.fd; | |
if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0) | |
return error.ServerShutdown; | |
if (events & std.os.EPOLLIN == 0) | |
unreachable; | |
while (true) { | |
const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) { | |
error.WouldBlock => return self.socket.register(std.os.EPOLLIN), | |
else => |e| return e, | |
}; | |
Client.init(client_fd, self.socket.poller) catch |err| { | |
std.os.close(client_fd); | |
std.debug.warn("Failed to spawn client: {}\n", .{err}); | |
}; | |
} | |
} | |
}; | |
const Client = struct { | |
socket: Socket, | |
send_bytes: usize = 0, | |
send_partial: usize = 0, | |
recv_bytes: usize = 0, | |
recv_buffer: [4096]u8 = undefined, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn init(fd: std.os.socket_t, poller: *Poller) !void { | |
const allocator = poller.getAllocator(); | |
const self = try allocator.create(Client); | |
errdefer allocator.destroy(self); | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
self.* = .{ .socket = undefined }; | |
try self.socket.init(fd, poller, Client); | |
} | |
fn deinit(self: *Client) void { | |
const allocator = self.socket.poller.getAllocator(); | |
allocator.destroy(self); | |
} | |
fn run(self: *Client) !void { | |
var events = self.socket.events; | |
if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0) | |
return error.Closed; | |
var written = false; | |
if ((events & std.os.EPOLLOUT != 0) and (self.send_bytes > 0)) { | |
written = true; | |
try self.processWrite(); | |
} | |
if (events & std.os.EPOLLIN != 0) | |
try self.processRead(); | |
if (!written and (events & std.os.EPOLLOUT != 0) and (self.send_bytes > 0)) | |
try self.processWrite(); | |
events = std.os.EPOLLIN; | |
if (self.send_bytes > 0) | |
events |= std.os.EPOLLOUT; | |
try self.socket.register(events); | |
} | |
fn processRead(self: *Client) !void { | |
while (true) { | |
const request_buffer = self.recv_buffer[0..self.recv_bytes]; | |
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| { | |
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len]; | |
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer); | |
self.recv_bytes = unparsed_buffer.len; | |
self.send_bytes += HTTP_RESPONSE.len; | |
continue; | |
} | |
const readable_buffer = self.recv_buffer[self.recv_bytes..]; | |
if (readable_buffer.len == 0) | |
return error.HttpRequestTooLarge; | |
const bytes_read = std.os.read(self.socket.fd, readable_buffer) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
self.recv_bytes += bytes_read; | |
if (bytes_read == 0) | |
return error.EndOfStream; | |
} | |
} | |
fn processWrite(self: *Client) !void { | |
const NUM_RESPONSE_CHUNKS = 128; | |
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS; | |
while (self.send_bytes > 0) { | |
// Compute the slice of the response chunk to send, based on send_bytes (bytes owed) and send_partial (offset into the current response).
const send_bytes = self.send_bytes; | |
if (self.send_partial > RESPONSE_CHUNK.len) | |
std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len}); | |
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial; | |
const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial); | |
const writable_buffer = iov_base[0..iov_len]; | |
// Perform the actual write. | |
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing. | |
const bytes_written = std.os.sendto( | |
self.socket.fd, | |
writable_buffer, | |
std.os.MSG_NOSIGNAL, | |
null, | |
@as(std.os.socklen_t, 0), | |
) catch |err| switch (err) { | |
error.WouldBlock => return, | |
else => |e| return e, | |
}; | |
// Advance the offset into the current response by the bytes just written.
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
self.send_bytes -= bytes_written; | |
} | |
} | |
};

---- next file (thread pool implementation) ----
const std = @import("std"); | |
const system = switch (std.builtin.os.tag) { | |
.linux => std.os.linux, | |
else => std.os.system, | |
}; | |
const ThreadPool = @This(); | |
max_threads: u16, | |
counter: u32 = 0, | |
spawned_queue: ?*Worker = null, | |
run_queue: UnboundedQueue = .{}, | |
idle_semaphore: Semaphore = Semaphore.init(0), | |
shutdown_event: Event = .{}, | |
pub const InitConfig = struct { | |
max_threads: ?u16 = null, | |
}; | |
pub fn init(config: InitConfig) ThreadPool { | |
return .{ | |
.max_threads = std.math.min( | |
std.math.maxInt(u14), | |
std.math.max(1, config.max_threads orelse blk: { | |
break :blk @intCast(u16, std.Thread.cpuCount() catch 1); | |
}), | |
), | |
}; | |
} | |
pub fn deinit(self: *ThreadPool) void { | |
defer self.* = undefined; | |
self.shutdown(); | |
self.shutdown_event.wait(); | |
while (self.spawned_queue) |worker| { | |
self.spawned_queue = worker.spawned_next; | |
const thread = worker.thread; | |
worker.shutdown_event.notify(); | |
thread.wait(); | |
} | |
} | |
pub const ScheduleHints = struct { | |
priority: Priority = .Normal, | |
pub const Priority = enum { | |
High, | |
Normal, | |
Low, | |
}; | |
}; | |
pub fn schedule(self: *ThreadPool, hints: ScheduleHints, batchable: anytype) void { | |
const batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (Worker.current) |worker| { | |
worker.push(hints, batch); | |
} else { | |
self.run_queue.push(batch); | |
} | |
_ = self.tryNotifyWith(false); | |
} | |
pub const SpawnConfig = struct { | |
allocator: *std.mem.Allocator, | |
hints: ScheduleHints = .{}, | |
}; | |
pub fn spawn(self: *ThreadPool, config: SpawnConfig, comptime func: anytype, args: anytype) !void { | |
const Args = @TypeOf(args); | |
const is_async = @typeInfo(@TypeOf(func)).Fn.calling_convention == .Async; | |
const Closure = struct { | |
func_args: Args, | |
allocator: *std.mem.Allocator, | |
runnable: Runnable = .{ .runFn = runFn }, | |
frame: if (is_async) @Frame(runAsyncFn) else void = undefined, | |
fn runFn(runnable: *Runnable) void { | |
const closure = @fieldParentPtr(@This(), "runnable", runnable); | |
if (is_async) { | |
closure.frame = async closure.runAsyncFn(); | |
} else { | |
const result = @call(.{}, func, closure.func_args); | |
closure.allocator.destroy(closure); | |
} | |
} | |
fn runAsyncFn(closure: *@This()) void { | |
const result = @call(.{}, func, closure.func_args); | |
suspend closure.allocator.destroy(closure); | |
} | |
}; | |
const allocator = config.allocator; | |
const closure = try allocator.create(Closure); | |
errdefer allocator.destroy(closure); | |
closure.* = .{ | |
.func_args = args, | |
.allocator = allocator, | |
}; | |
const hints = config.hints; | |
self.schedule(hints, &closure.runnable); | |
} | |
const Counter = struct { | |
state: State = .pending, | |
idle: u16 = 0, | |
spawned: u16 = 0, | |
const State = enum(u4) { | |
pending = 0, | |
notified, | |
waking, | |
waker_notified, | |
shutdown, | |
}; | |
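// The pool state is packed into one u32 so it can be updated with atomic CAS:
// 4 bits of State, 14 bits of idle-worker count, 14 bits of spawned-worker count
// (max_threads is capped at maxInt(u14) accordingly).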
fn pack(self: Counter) u32 { | |
return (@as(u32, @as(u4, @enumToInt(self.state))) | | |
(@as(u32, @intCast(u14, self.idle)) << 4) | | |
(@as(u32, @intCast(u14, self.spawned)) << (4 + 14))); | |
} | |
fn unpack(value: u32) Counter { | |
return Counter{ | |
.state = @intToEnum(State, @truncate(u4, value)), | |
.idle = @as(u16, @truncate(u14, value >> 4)), | |
.spawned = @as(u16, @truncate(u14, value >> (4 + 14))), | |
}; | |
} | |
}; | |
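// Notify path of the wake protocol: wake one idle worker (via the semaphore) or spawn
// a new thread, while the packed state keeps at most one "waking" worker searching for
// work at a time; otherwise the notification is merely recorded in the state.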
fn tryNotifyWith(self: *ThreadPool, is_caller_waking: bool) bool { | |
var spawned = false; | |
var remaining_attempts: u8 = 5; | |
var is_waking = is_caller_waking; | |
while (true) : (yieldCpu()) { | |
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
if (counter.state == .shutdown) { | |
if (spawned) | |
self.releaseWorker(); | |
return false; | |
} | |
const has_pending = (counter.idle > 0) or (counter.spawned < self.max_threads); | |
const can_wake = (is_waking and remaining_attempts > 0) or (!is_waking and counter.state == .pending); | |
if (has_pending and can_wake) { | |
var new_counter = counter; | |
new_counter.state = .waking; | |
if (counter.idle > 0) { | |
new_counter.idle -= 1; | |
} else if (!spawned) { | |
new_counter.spawned += 1; | |
} | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
)) |failed| { | |
continue; | |
} | |
is_waking = true; | |
if (counter.idle > 0) { | |
self.idle_semaphore.post(1) catch unreachable; | |
return true; | |
} | |
spawned = true; | |
if (Worker.spawn(self)) | |
return true; | |
remaining_attempts -= 1; | |
continue; | |
} | |
var new_counter = counter; | |
if (is_waking) { | |
new_counter.state = if (can_wake) .pending else .notified; | |
if (spawned) | |
new_counter.spawned -= 1; | |
} else if (counter.state == .waking) { | |
new_counter.state = .waker_notified; | |
} else if (counter.state == .pending) { | |
new_counter.state = .notified; | |
} else { | |
return false; | |
} | |
_ = @cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Monotonic, | |
.Monotonic, | |
) orelse return true; | |
} | |
} | |
const Wait = enum { | |
resumed, | |
notified, | |
shutdown, | |
}; | |
fn tryWaitWith(self: *ThreadPool, is_caller_waking: bool) Wait { | |
var is_waking = is_caller_waking; | |
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
while (true) { | |
if (counter.state == .shutdown) { | |
self.releaseWorker(); | |
return .shutdown; | |
} | |
const is_notified = switch (counter.state) { | |
.waker_notified => is_waking, | |
.notified => true, | |
else => false, | |
}; | |
var new_counter = counter; | |
if (is_notified) { | |
new_counter.state = if (is_waking) .waking else .pending; | |
} else { | |
new_counter.idle += 1; | |
if (is_waking) | |
new_counter.state = .pending; | |
} | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Monotonic, | |
.Monotonic, | |
)) |updated| { | |
counter = Counter.unpack(updated); | |
continue; | |
} | |
if (is_notified and is_waking) | |
return .notified; | |
if (is_notified) | |
return .resumed; | |
self.idle_semaphore.wait(1); | |
return .notified; | |
} | |
} | |
fn releaseWorker(self: *ThreadPool) void { | |
const counter_spawned = Counter{ .spawned = 1 }; | |
const counter_value = @atomicRmw(u32, &self.counter, .Sub, counter_spawned.pack(), .AcqRel); | |
const counter = Counter.unpack(counter_value); | |
if (counter.state != .shutdown) | |
std.debug.panic("ThreadPool.releaseWorker() when not shutdown: {}", .{counter}); | |
if (counter.spawned == 1) | |
self.shutdown_event.notify(); | |
} | |
pub fn shutdown(self: *ThreadPool) void { | |
while (true) : (yieldCpu()) { | |
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
if (counter.state == .shutdown) | |
return; | |
var new_counter = counter; | |
new_counter.state = .shutdown; | |
new_counter.idle = 0; | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
)) |failed| { | |
continue; | |
} | |
self.idle_semaphore.post(self.max_threads) catch unreachable; | |
return; | |
} | |
} | |
const Worker = struct { | |
pool: *ThreadPool, | |
thread: *std.Thread, | |
spawned_next: ?*Worker = null, | |
shutdown_event: Event = .{}, | |
run_queue: BoundedQueue = .{}, | |
run_queue_next: ?*Runnable = null, | |
run_queue_lifo: ?*Runnable = null, | |
run_queue_overflow: UnboundedQueue = .{}, | |
tick: usize = undefined, | |
is_waking: bool = true, | |
next_target: ?*Worker = null, | |
threadlocal var current: ?*Worker = null; | |
fn spawn(pool: *ThreadPool) bool { | |
const Spawner = struct { | |
thread: *std.Thread = undefined, | |
thread_pool: *ThreadPool, | |
data_put_event: Event = .{}, | |
data_get_event: Event = .{}, | |
fn entry(self: *@This()) void { | |
self.data_put_event.wait(); | |
const thread = self.thread; | |
const thread_pool = self.thread_pool; | |
self.data_get_event.notify(); | |
Worker.run(thread, thread_pool); | |
} | |
}; | |
var spawner = Spawner{ .thread_pool = pool }; | |
spawner.thread = std.Thread.spawn(&spawner, Spawner.entry) catch return false; | |
spawner.data_put_event.notify(); | |
spawner.data_get_event.wait(); | |
return true; | |
} | |
fn run(thread: *std.Thread, pool: *ThreadPool) void { | |
var self = Worker{ | |
.thread = thread, | |
.pool = pool, | |
}; | |
self.tick = @ptrToInt(&self); | |
current = &self; | |
defer current = null; | |
var spawned_queue = @atomicLoad(?*Worker, &pool.spawned_queue, .Monotonic); | |
while (true) { | |
self.spawned_next = spawned_queue; | |
spawned_queue = @cmpxchgWeak( | |
?*Worker, | |
&pool.spawned_queue, | |
spawned_queue, | |
&self, | |
.Release, | |
.Monotonic, | |
) orelse break; | |
} | |
while (true) { | |
if (self.pop()) |runnable| { | |
if (self.is_waking) { | |
self.is_waking = false; | |
_ = pool.tryNotifyWith(true); | |
} | |
self.tick +%= 1; | |
runnable.run(); | |
continue; | |
} | |
self.is_waking = switch (pool.tryWaitWith(self.is_waking)) { | |
.resumed => false, | |
.notified => true, | |
.shutdown => { | |
self.shutdown_event.wait(); | |
break; | |
}, | |
}; | |
} | |
} | |
fn push(self: *Worker, hints: ScheduleHints, batchable: anytype) void { | |
var batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (hints.priority == .High) { | |
const new_lifo = batch.pop(); | |
if (@atomicLoad(?*Runnable, &self.run_queue_lifo, .Monotonic) == null) { | |
@atomicStore(?*Runnable, &self.run_queue_lifo, new_lifo, .Release); | |
} else if (@atomicRmw(?*Runnable, &self.run_queue_lifo, .Xchg, new_lifo, .AcqRel)) |old_lifo| { | |
batch.pushFront(old_lifo); | |
} | |
} | |
if (hints.priority == .Low) { | |
if (self.run_queue_next) |old_next| | |
batch.pushFront(old_next); | |
self.run_queue_next = null; | |
self.run_queue_next = self.pop() orelse batch.pop(); | |
} | |
if (self.run_queue.push(batch)) |overflowed| | |
self.run_queue_overflow.push(overflowed); | |
} | |
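// Work acquisition order: periodically (on prime tick counts) check remote sources so
// the shared, overflow, and LIFO queues are not starved, then fall back to the local
// ring buffer, the LIFO slot, the overflow queue, the pool queue, and finally stealing
// from other workers.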
fn pop(self: *Worker) ?*Runnable { | |
if (self.tick % 127 == 0) { | |
if (self.popAndStealFromOthers()) |runnable| | |
return runnable; | |
} | |
if (self.tick % 61 == 0) { | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
} | |
if (self.tick % 31 == 0) { | |
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable| | |
return runnable; | |
} | |
if (self.tick % 13 == 0) { | |
if (self.popAndStealLifo(self)) |runnable| | |
return runnable; | |
} | |
if (self.run_queue.pop()) |runnable| | |
return runnable; | |
if (self.popAndStealLifo(self)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
if (self.popAndStealFromOthers()) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
return null; | |
} | |
fn popAndStealLifo(self: *Worker, target: *Worker) ?*Runnable { | |
var run_queue_lifo = @atomicLoad(?*Runnable, &target.run_queue_lifo, .Monotonic); | |
while (true) { | |
if (run_queue_lifo == null) | |
return null; | |
run_queue_lifo = @cmpxchgWeak( | |
?*Runnable, | |
&target.run_queue_lifo, | |
run_queue_lifo, | |
null, | |
.Acquire, | |
.Monotonic, | |
) orelse return run_queue_lifo; | |
} | |
} | |
fn popAndStealFromOthers(self: *Worker) ?*Runnable { | |
var num_workers = blk: { | |
const counter_value = @atomicLoad(u32, &self.pool.counter, .Monotonic); | |
const counter = Counter.unpack(counter_value); | |
break :blk counter.spawned; | |
}; | |
while (num_workers > 0) : (num_workers -= 1) { | |
const target = self.next_target orelse blk: { | |
break :blk @atomicLoad(?*Worker, &self.pool.spawned_queue, .Acquire) orelse { | |
std.debug.panic("Worker observed empty spawned queue when work-stealing", .{}); | |
}; | |
}; | |
self.next_target = target.spawned_next; | |
if (target == self) | |
continue; | |
if (self.run_queue.popAndStealBounded(&target.run_queue)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&target.run_queue_overflow)) |runnable| | |
return runnable; | |
if (self.popAndStealLifo(target)) |runnable| | |
return runnable; | |
} | |
return null; | |
} | |
}; | |
const UnboundedQueue = struct { | |
lock: Mutex = .{}, | |
batch: Batch = .{}, | |
shared_size: usize = 0, | |
fn push(self: *UnboundedQueue, batchable: anytype) void { | |
const batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
const held = self.lock.acquire(); | |
defer held.release(); | |
self.batch.push(batch); | |
var shared_size = self.shared_size; | |
shared_size += batch.size; | |
@atomicStore(usize, &self.shared_size, shared_size, .Release); | |
} | |
fn tryAcquireConsumer(self: *UnboundedQueue) ?Consumer { | |
var shared_size = @atomicLoad(usize, &self.shared_size, .Acquire); | |
if (shared_size == 0) | |
return null; | |
const held = self.lock.acquire(); | |
shared_size = self.shared_size; | |
if (shared_size == 0) { | |
held.release(); | |
return null; | |
} | |
return Consumer{ | |
.held = held, | |
.queue = self, | |
.size = shared_size, | |
}; | |
} | |
const Consumer = struct { | |
held: Mutex.Held, | |
queue: *UnboundedQueue, | |
size: usize, | |
fn release(self: Consumer) void { | |
@atomicStore(usize, &self.queue.shared_size, self.size, .Release); | |
self.held.release(); | |
} | |
fn pop(self: *Consumer) ?*Runnable { | |
const runnable = self.queue.batch.pop() orelse return null; | |
self.size -= 1; | |
return runnable; | |
} | |
}; | |
}; | |
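// Fixed-size (256 entry) ring buffer owned by one worker; other workers steal from the
// head. push() returns the overflow: when the buffer is full, half of it plus the rest
// of the batch is handed back so the caller can move it to an unbounded queue.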
const BoundedQueue = struct { | |
head: usize = 0, | |
tail: usize = 0, | |
buffer: [256]*Runnable = undefined, | |
fn push(self: *BoundedQueue, batchable: anytype) ?Batch { | |
var batch = Batch.from(batchable); | |
while (true) : (yieldCpu()) { | |
if (batch.isEmpty()) | |
return null; | |
var tail = self.tail; | |
var head = @atomicLoad(usize, &self.head, .Acquire); | |
var size = tail -% head; | |
if (size < self.buffer.len) { | |
while (size < self.buffer.len) { | |
const runnable = batch.pop() orelse break; | |
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered); | |
tail +%= 1; | |
size += 1; | |
} | |
@atomicStore(usize, &self.tail, tail, .Release); | |
continue; | |
} | |
var migrate = self.buffer.len / 2; | |
if (@cmpxchgWeak( | |
usize, | |
&self.head, | |
head, | |
head +% migrate, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
var overflowed = Batch{}; | |
while (migrate > 0) : (migrate -= 1) { | |
const runnable = self.buffer[head % self.buffer.len]; | |
overflowed.push(runnable); | |
head +%= 1; | |
} | |
overflowed.push(batch); | |
return overflowed; | |
} | |
} | |
fn pop(self: *BoundedQueue) ?*Runnable { | |
while (true) : (yieldCpu()) { | |
const tail = self.tail; | |
const head = @atomicLoad(usize, &self.head, .Acquire); | |
const size = tail -% head; | |
if (size == 0) | |
return null; | |
if (@cmpxchgWeak( | |
usize, | |
&self.head, | |
head, | |
head +% 1, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
const runnable = self.buffer[head % self.buffer.len]; | |
return runnable; | |
} | |
} | |
fn popAndStealBounded(self: *BoundedQueue, target: *BoundedQueue) ?*Runnable { | |
if (target == self) | |
return self.pop(); | |
const tail = self.tail; | |
const head = @atomicLoad(usize, &self.head, .Acquire); | |
const size = tail -% head; | |
if (size != 0) | |
return self.pop(); | |
while (true) : (yieldThread()) { | |
const target_head = @atomicLoad(usize, &target.head, .Acquire); | |
const target_tail = @atomicLoad(usize, &target.tail, .Acquire); | |
const target_size = target_tail -% target_head; | |
var steal = target_size - (target_size / 2); | |
if (steal == 0) | |
return null; | |
if (steal > target.buffer.len / 2) | |
continue; | |
const first_runnable = @atomicLoad(*Runnable, &target.buffer[target_head % target.buffer.len], .Unordered); | |
var new_target_head = target_head +% 1; | |
var new_tail = tail; | |
steal -= 1; | |
while (steal > 0) : (steal -= 1) { | |
const runnable = @atomicLoad(*Runnable, &target.buffer[new_target_head % target.buffer.len], .Unordered); | |
new_target_head +%= 1; | |
@atomicStore(*Runnable, &self.buffer[new_tail % self.buffer.len], runnable, .Unordered); | |
new_tail +%= 1; | |
} | |
if (@cmpxchgWeak( | |
usize, | |
&target.head, | |
target_head, | |
new_target_head, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
@atomicStore(usize, &self.tail, new_tail, .Release); | |
return first_runnable; | |
} | |
} | |
fn popAndStealUnbounded(self: *BoundedQueue, target: *UnboundedQueue) ?*Runnable { | |
var consumer = target.tryAcquireConsumer() orelse return null; | |
defer consumer.release(); | |
const first_runnable = consumer.pop() orelse return null; | |
var tail = self.tail; | |
var head = @atomicLoad(usize, &self.head, .Acquire); | |
var size = tail -% head; | |
while (size < self.buffer.len) { | |
const runnable = consumer.pop() orelse break; | |
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered); | |
tail +%= 1; | |
size += 1; | |
} | |
@atomicStore(usize, &self.tail, tail, .Release); | |
return first_runnable; | |
} | |
}; | |
pub const Runnable = struct { | |
next: ?*Runnable = null, | |
runFn: fn (*Runnable) void, | |
pub fn run(self: *Runnable) void { | |
return (self.runFn)(self); | |
} | |
}; | |
pub const Batch = struct { | |
head: ?*Runnable = null, | |
tail: *Runnable = undefined, | |
size: usize = 0, | |
pub fn from(batchable: anytype) Batch { | |
return switch (@TypeOf(batchable)) { | |
Batch => batchable, | |
?*Runnable => from(batchable orelse return Batch{}), | |
*Runnable => { | |
batchable.next = null; | |
return Batch{ | |
.head = batchable, | |
.tail = batchable, | |
.size = 1, | |
}; | |
}, | |
else => |typ| @compileError(@typeName(typ) ++ | |
" cannot be converted into " ++ | |
@typeName(Batch)), | |
}; | |
} | |
pub fn isEmpty(self: Batch) bool { | |
return self.head == null; | |
} | |
pub const push = pushBack; | |
pub fn pushBack(self: *Batch, batchable: anytype) void { | |
const batch = from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (self.isEmpty()) { | |
self.* = batch; | |
} else { | |
self.tail.next = batch.head; | |
self.tail = batch.tail; | |
self.size += batch.size; | |
} | |
} | |
pub fn pushFront(self: *Batch, batchable: anytype) void { | |
const batch = from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (self.isEmpty()) { | |
self.* = batch; | |
} else { | |
batch.tail.next = self.head; | |
self.head = batch.head; | |
self.size += batch.size; | |
} | |
} | |
pub const pop = popFront; | |
pub fn popFront(self: *Batch) ?*Runnable { | |
const runnable = self.head orelse return null; | |
self.head = runnable.next; | |
self.size -= 1; | |
return runnable; | |
} | |
}; | |
const Semaphore = struct { | |
lock: Mutex = .{}, | |
permits: usize = 0, | |
waiters: ?*Waiter = null, | |
const Waiter = struct { | |
next: ?*Waiter = null, | |
tail: *Waiter = undefined, | |
event: Event = .{}, | |
permits: usize, | |
}; | |
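// Counting semaphore built on the Mutex below: waiters queue FIFO and block on a
// per-waiter Event until post() has granted them enough permits.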
fn init(permits: usize) Semaphore { | |
return .{ .permits = permits }; | |
} | |
fn wait(self: *Semaphore, permits: usize) void { | |
const held = self.lock.acquire(); | |
if (self.permits >= permits) { | |
self.permits -= permits; | |
held.release(); | |
return; | |
} | |
var waiter = Waiter{ .permits = permits }; | |
if (self.waiters) |head| { | |
head.tail.next = &waiter; | |
head.tail = &waiter; | |
} else { | |
self.waiters = &waiter; | |
waiter.tail = &waiter; | |
} | |
held.release(); | |
waiter.event.wait(); | |
} | |
fn post(self: *Semaphore, permits: usize) error{Overflow}!void { | |
var waiters: ?*Waiter = null; | |
{ | |
const held = self.lock.acquire(); | |
defer held.release(); | |
if (@addWithOverflow(usize, self.permits, permits, &self.permits)) | |
return error.Overflow; | |
while (self.waiters) |waiter| { | |
if (waiter.permits > self.permits) | |
break; | |
self.waiters = waiter.next; | |
if (self.waiters) |new_waiter| | |
new_waiter.tail = waiter.tail; | |
self.permits -= waiter.permits; | |
waiter.next = waiters; | |
waiters = waiter; | |
} | |
} | |
while (waiters) |waiter| { | |
waiters = waiter.next; | |
waiter.event.notify(); | |
} | |
} | |
}; | |
const Mutex = if (std.builtin.os.tag == .windows) | |
struct { | |
srwlock: usize = 0, | |
pub fn acquire(self: *Mutex) Held { | |
AcquireSRWLockExclusive(&self.srwlock); | |
return Held{ .mutex = self };
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
ReleaseSRWLockExclusive(&self.mutex.srwlock); | |
} | |
}; | |
extern "kernel32" fn AcquireSRWLockExclusive( | |
srwlock: *?system.PVOID, | |
) callconv(system.WINAPI) void; | |
extern "kernel32" fn ReleaseSRWLockExclusive( | |
srwlock: *?system.PVOID, | |
) callconv(system.WINAPI) void; | |
} | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
lock: u32 = 0, | |
pub fn acquire(self: *Mutex) Held { | |
os_unfair_lock_lock(&self.lock); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
os_unfair_lock_unlock(&self.mutex.lock); | |
} | |
}; | |
extern "c" fn os_unfair_lock_lock( | |
unfair_lock: *u32, | |
) callconv(.C) void; | |
extern "c" fn os_unfair_lock_unlock( | |
unfair_lock: *u32, | |
) callconv(.C) void; | |
} | |
else if (std.builtin.os.tag == .linux) | |
struct { | |
state: i32 = UNLOCKED, | |
const UNLOCKED: i32 = 0; | |
const LOCKED: i32 = 1; | |
const WAITING: i32 = 2; | |
pub fn acquire(self: *Mutex) Held { | |
const state = @atomicRmw(i32, &self.state, .Xchg, LOCKED, .Acquire); | |
if (state != UNLOCKED) | |
self.acquireSlow(state); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
switch (@atomicRmw(i32, &self.mutex.state, .Xchg, UNLOCKED, .Release)) { | |
UNLOCKED => unreachable, // unlocked an unlocked mutex | |
LOCKED => {}, | |
WAITING => self.mutex.releaseSlow(), | |
else => unreachable, | |
} | |
} | |
}; | |
fn acquireSlow(self: *Mutex, current_state: i32) void { | |
@setCold(true); | |
var wait_state = current_state; | |
while (true) { | |
var spin: u8 = 0; | |
while (spin < 5) : (spin += 1) { | |
switch (@atomicLoad(i32, &self.state, .Monotonic)) { | |
UNLOCKED => _ = @cmpxchgWeak( | |
i32, | |
&self.state, | |
UNLOCKED, | |
wait_state, | |
.Acquire, | |
.Monotonic, | |
) orelse return, | |
LOCKED => {}, | |
WAITING => break, | |
else => unreachable, | |
} | |
if (spin < 4) { | |
var pause: u8 = 30; | |
while (pause > 0) : (pause -= 1) | |
yieldCpu(); | |
} else { | |
yieldThread(); | |
} | |
} | |
const state = @atomicRmw(i32, &self.state, .Xchg, WAITING, .Acquire); | |
if (state == UNLOCKED) | |
return; | |
wait_state = WAITING; | |
switch (system.getErrno(system.futex_wait( | |
&self.state, | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT, | |
WAITING, | |
null, | |
))) { | |
0 => {}, | |
system.EINTR => {}, | |
system.EAGAIN => {}, | |
else => unreachable, | |
} | |
} | |
} | |
fn releaseSlow(self: *Mutex) void { | |
@setCold(true); | |
while (true) { | |
return switch (system.getErrno(system.futex_wake( | |
&self.state, | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE, | |
@as(i32, 1), | |
))) { | |
0 => {}, | |
system.EINTR => continue, | |
system.EFAULT => {}, | |
else => unreachable, | |
}; | |
} | |
} | |
} | |
else | |
struct { | |
locked: bool = false, | |
pub fn acquire(self: *Mutex) Held { | |
while (@atomicRmw(bool, &self.locked, .Xchg, true, .Acquire)) | |
yieldThread(); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
@atomicStore(bool, &self.mutex.locked, false, .Release); | |
} | |
}; | |
}; | |
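// Auto-reset event used for blocking waits (e.g. by Semaphore): keyed events
// on Windows, __ulock on Darwin, a futex on Linux, and spin-waiting otherwise.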
const Event = if (std.builtin.os.tag == .windows) | |
struct { | |
key: u32 = undefined, | |
pub fn wait(self: *Event) void { | |
const status = NtWaitForKeyedEvent(null, &self.key, system.FALSE, null); | |
std.debug.assert(status == .SUCCESS); | |
} | |
pub fn notify(self: *Event) void { | |
const status = NtReleaseKeyedEvent(null, &self.key, system.FALSE, null); | |
std.debug.assert(status == .SUCCESS); | |
} | |
extern "NtDll" fn NtWaitForKeyedEvent( | |
handle: ?system.HANDLE, | |
key: ?*const u32, | |
alertable: system.BOOLEAN, | |
timeout: ?*const system.LARGE_INTEGER, | |
) callconv(system.WINAPI) system.NTSTATUS; | |
extern "NtDll" fn NtReleaseKeyedEvent( | |
handle: ?system.HANDLE, | |
key: ?*const u32, | |
alertable: system.BOOLEAN, | |
timeout: ?*const system.LARGE_INTEGER, | |
) callconv(system.WINAPI) system.NTSTATUS; | |
} | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
state: enum(u32) { | |
pending = 0, | |
notified, | |
} = .pending, | |
pub fn wait(self: *Event) void { | |
while (true) { | |
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) { | |
.pending => {}, | |
.notified => { | |
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic); | |
return; | |
}, | |
} | |
const status = __ulock_wait( | |
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO, | |
@ptrCast(?*const c_void, &self.state), | |
@enumToInt(@TypeOf(self.state).pending), | |
~@as(u32, 0), | |
); | |
if (status < 0) { | |
switch (-status) { | |
system.EINTR => {}, | |
else => unreachable, | |
} | |
} | |
} | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release); | |
while (true) { | |
const status = __ulock_wake( | |
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO, | |
@ptrCast(?*const c_void, &self.state), | |
@as(u32, 0), | |
); | |
if (status < 0) { | |
switch (-status) { | |
system.ENOENT => {}, | |
system.EINTR => continue, | |
else => unreachable, | |
} | |
} | |
return; | |
} | |
} | |
const ULF_NO_ERRNO = 0x1000000; | |
const UL_COMPARE_AND_WAIT = 0x1; | |
extern "c" fn __ulock_wait( | |
operation: u32, | |
address: ?*const c_void, | |
value: u64, | |
timeout_us: u32, | |
) callconv(.C) c_int; | |
extern "c" fn __ulock_wake( | |
operation: u32, | |
address: ?*const c_void, | |
value: u64, | |
) callconv(.C) c_int; | |
} | |
else if (std.builtin.os.tag == .linux) | |
struct { | |
state: enum(i32) { | |
pending, | |
notified, | |
} = .pending, | |
pub fn wait(self: *Event) void { | |
while (true) { | |
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) { | |
.pending => {}, | |
.notified => { | |
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic); | |
return; | |
}, | |
} | |
switch (system.getErrno(system.futex_wait( | |
@ptrCast(*const i32, &self.state), | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT, | |
@enumToInt(@TypeOf(self.state).pending), | |
null, | |
))) { | |
0 => {}, | |
system.EINTR => {}, | |
system.EAGAIN => {}, | |
else => unreachable, | |
} | |
} | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release); | |
while (true) { | |
return switch (system.getErrno(system.futex_wake( | |
@ptrCast(*const i32, &self.state), | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE, | |
@as(i32, 1), | |
))) { | |
0 => {}, | |
system.EINTR => continue, | |
system.EFAULT => {}, | |
else => unreachable, | |
}; | |
} | |
} | |
} | |
else | |
struct { | |
notified: bool = false, | |
pub fn wait(self: *Event) void { | |
while (!@atomicLoad(bool, &self.notified, .Acquire)) | |
yieldThread(); | |
@atomicStore(bool, &self.notified, false, .Monotonic); | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(bool, &self.notified, true, .Release); | |
} | |
}; | |
const yieldThread = if (std.builtin.os.tag == .windows) | |
struct { | |
fn yield() void { | |
system.kernel32.Sleep(0); | |
} | |
}.yield | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
fn yield() void { | |
_ = thread_switch(MACH_PORT_NULL, SWITCH_OPTION_DEPRESS, 1); | |
} | |
const MACH_PORT_NULL = 0; | |
const SWITCH_OPTION_DEPRESS = 1; | |
// https://www.gnu.org/software/hurd/gnumach-doc/Hand_002dOff-Scheduling.html | |
extern "c" fn thread_switch( | |
thread: usize, | |
options: c_int, | |
timeout_ms: c_int, | |
) callconv(.C) c_int; | |
}.yield | |
else if (std.builtin.os.tag == .linux or std.builtin.link_libc) | |
struct { | |
fn yield() void { | |
_ = system.sched_yield(); | |
} | |
}.yield | |
else | |
yieldCpu; | |
fn yieldCpu() void { | |
switch (std.builtin.arch) { | |
.i386, .x86_64 => asm volatile ("pause"), | |
.arm, .aarch64 => asm volatile ("yield"), | |
else => {}, | |
} | |
} |
const std = @import("std"); | |
const system = switch (std.builtin.os.tag) { | |
.linux => std.os.linux, | |
else => std.os.system, | |
}; | |
const ThreadPool = @This(); | |
io_driver: IoDriver, | |
max_threads: u16, | |
counter: u32 = 0, | |
spawned_queue: ?*Worker = null, | |
run_queue: UnboundedQueue = .{}, | |
shutdown_event: Event = .{}, | |
pub const InitConfig = struct { | |
max_threads: ?u16 = null, | |
}; | |
pub fn init(config: InitConfig) !ThreadPool { | |
return ThreadPool{ | |
.io_driver = try IoDriver.init(), | |
.max_threads = std.math.min( | |
std.math.maxInt(u14), | |
std.math.max(1, config.max_threads orelse blk: { | |
break :blk @intCast(u16, std.Thread.cpuCount() catch 1); | |
}), | |
), | |
}; | |
} | |
pub fn deinit(self: *ThreadPool) void { | |
defer self.* = undefined; | |
defer self.io_driver.deinit(); | |
self.shutdown(); | |
self.shutdown_event.wait(); | |
while (self.spawned_queue) |worker| { | |
self.spawned_queue = worker.spawned_next; | |
const thread = worker.thread; | |
worker.shutdown_event.notify(); | |
thread.wait(); | |
} | |
} | |
pub const ScheduleHints = struct { | |
priority: Priority = .Normal, | |
pub const Priority = enum { | |
High, | |
Normal, | |
Low, | |
}; | |
}; | |
pub fn schedule(self: *ThreadPool, hints: ScheduleHints, batchable: anytype) void { | |
const batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (Worker.current) |worker| { | |
worker.push(hints, batch); | |
} else { | |
self.run_queue.push(batch); | |
} | |
_ = self.tryNotifyWith(false); | |
} | |
pub const SpawnConfig = struct { | |
allocator: *std.mem.Allocator, | |
hints: ScheduleHints = .{}, | |
}; | |
pub fn spawn(self: *ThreadPool, config: SpawnConfig, comptime func: anytype, args: anytype) !void { | |
const Args = @TypeOf(args); | |
const is_async = @typeInfo(@TypeOf(func)).Fn.calling_convention == .Async; | |
const Closure = struct { | |
func_args: Args, | |
allocator: *std.mem.Allocator, | |
runnable: Runnable = .{ .runFn = runFn }, | |
frame: if (is_async) @Frame(runAsyncFn) else void = undefined, | |
fn runFn(runnable: *Runnable) void { | |
const closure = @fieldParentPtr(@This(), "runnable", runnable); | |
if (is_async) { | |
closure.frame = async closure.runAsyncFn(); | |
} else { | |
const result = @call(.{}, func, closure.func_args); | |
closure.allocator.destroy(closure); | |
} | |
} | |
fn runAsyncFn(closure: *@This()) void { | |
const result = @call(.{}, func, closure.func_args); | |
suspend closure.allocator.destroy(closure); | |
} | |
}; | |
const allocator = config.allocator; | |
const closure = try allocator.create(Closure); | |
errdefer allocator.destroy(closure); | |
closure.* = .{ | |
.func_args = args, | |
.allocator = allocator, | |
}; | |
const hints = config.hints; | |
self.schedule(hints, &closure.runnable); | |
} | |
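// Example usage of spawn() (a rough sketch, not from the original gist;
// `allocator` and `myTask` are placeholders):
//     var pool = try ThreadPool.init(.{});
//     defer pool.deinit();
//     try pool.spawn(.{ .allocator = allocator }, myTask, .{});
// The heap-allocated Closure frees itself after the wrapped function has run
// (or, for async functions, at the frame's final suspend point).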
const Counter = struct { | |
state: State = .pending, | |
notified: bool = false, | |
idle: u16 = 0, | |
spawned: u16 = 0, | |
const State = enum(u3) { | |
pending = 0, | |
notified, | |
waking, | |
waker_notified, | |
shutdown, | |
}; | |
fn pack(self: Counter) u32 { | |
return (@as(u32, @as(u3, @enumToInt(self.state))) | | |
(@as(u32, @boolToInt(self.notified)) << 3) | | |
(@as(u32, @intCast(u14, self.idle)) << 4) | | |
(@as(u32, @intCast(u14, self.spawned)) << (4 + 14))); | |
} | |
fn unpack(value: u32) Counter { | |
return Counter{ | |
.state = @intToEnum(State, @truncate(u3, value)), | |
.notified = value & (1 << 3) != 0, | |
.idle = @as(u16, @truncate(u14, value >> 4)), | |
.spawned = @as(u16, @truncate(u14, value >> (4 + 14))), | |
}; | |
} | |
}; | |
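// Wake an idle worker (through the io_driver eventfd) or spawn a new thread,
// using a CAS loop over the packed Counter to claim the single "waking" slot.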
fn tryNotifyWith(self: *ThreadPool, is_caller_waking: bool) bool { | |
var spawned = false; | |
var remaining_attempts: u8 = 5; | |
var is_waking = is_caller_waking; | |
while (true) : (yieldCpu()) { | |
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
if (counter.state == .shutdown) { | |
if (spawned) | |
self.releaseWorker(); | |
return false; | |
} | |
const has_pending = (counter.idle > 0) or (counter.spawned < self.max_threads); | |
const can_wake = (is_waking and remaining_attempts > 0) or (!is_waking and counter.state == .pending); | |
if (has_pending and can_wake) { | |
var new_counter = counter; | |
new_counter.state = .waking; | |
if (counter.idle > 0) { | |
new_counter.idle -= 1; | |
new_counter.notified = true; | |
} else if (!spawned) { | |
new_counter.spawned += 1; | |
} | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
)) |failed| { | |
continue; | |
} | |
is_waking = true; | |
if (counter.idle > 0) { | |
self.idleNotify(); | |
return true; | |
} | |
spawned = true; | |
if (Worker.spawn(self)) | |
return true; | |
remaining_attempts -= 1; | |
continue; | |
} | |
var new_counter = counter; | |
if (is_waking) { | |
new_counter.state = if (can_wake) .pending else .notified; | |
if (spawned) | |
new_counter.spawned -= 1; | |
} else if (counter.state == .waking) { | |
new_counter.state = .waker_notified; | |
} else if (counter.state == .pending) { | |
new_counter.state = .notified; | |
} else { | |
return false; | |
} | |
_ = @cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Monotonic, | |
.Monotonic, | |
) orelse return true; | |
} | |
} | |
const Wait = enum { | |
resumed, | |
notified, | |
shutdown, | |
}; | |
fn tryWaitWith(self: *ThreadPool, worker: *Worker) Wait { | |
var is_waking = worker.is_waking; | |
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
while (true) { | |
if (counter.state == .shutdown) { | |
self.releaseWorker(); | |
return .shutdown; | |
} | |
const is_notified = switch (counter.state) { | |
.waker_notified => is_waking, | |
.notified => true, | |
else => false, | |
}; | |
var new_counter = counter; | |
if (is_notified) { | |
new_counter.state = if (is_waking) .waking else .pending; | |
} else { | |
new_counter.idle += 1; | |
if (is_waking) | |
new_counter.state = .pending; | |
} | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Monotonic, | |
.Monotonic, | |
)) |updated| { | |
counter = Counter.unpack(updated); | |
continue; | |
} | |
if (is_notified and is_waking) | |
return .notified; | |
if (is_notified) | |
return .resumed; | |
self.idleWait(worker); | |
return .notified; | |
} | |
} | |
fn releaseWorker(self: *ThreadPool) void { | |
const counter_spawned = Counter{ .spawned = 1 }; | |
const counter_value = @atomicRmw(u32, &self.counter, .Sub, counter_spawned.pack(), .AcqRel); | |
const counter = Counter.unpack(counter_value); | |
if (counter.state != .shutdown) | |
std.debug.panic("ThreadPool.releaseWorker() when not shutdown: {}", .{counter}); | |
if (counter.spawned == 1) | |
self.shutdown_event.notify(); | |
} | |
pub fn shutdown(self: *ThreadPool) void { | |
while (true) : (yieldCpu()) { | |
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
if (counter.state == .shutdown) | |
return; | |
var new_counter = counter; | |
new_counter.state = .shutdown; | |
new_counter.idle = 0; | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
)) |failed| { | |
continue; | |
} | |
self.idleShutdown(); | |
return; | |
} | |
} | |
fn idleWait(self: *ThreadPool, worker: *Worker) void { | |
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
while (true) { | |
if (counter.state == .shutdown) { | |
self.io_driver.notify(); | |
return; | |
} | |
if (counter.notified) { | |
var new_counter = counter; | |
new_counter.notified = false; | |
counter = Counter.unpack(@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
) orelse return); | |
continue; | |
} | |
const batch = self.io_driver.wait(); | |
self.schedule(.{}, batch); | |
counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
} | |
} | |
fn idleNotify(self: *ThreadPool) void { | |
self.io_driver.notify(); | |
} | |
fn idleShutdown(self: *ThreadPool) void { | |
self.io_driver.notify(); | |
} | |
pub const IoRunnable = struct { | |
fd: std.os.fd_t = -1, | |
is_closable: bool = false, | |
is_readable: bool = false, | |
is_writable: bool = false, | |
runnable: Runnable, | |
}; | |
pub fn waitFor(self: *ThreadPool, fd: std.os.fd_t, io_runnable: *IoRunnable) !void { | |
return self.io_driver.register(fd, io_runnable); | |
} | |
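// epoll-based I/O driver: idle workers block in wait(); notify() wakes one by
// re-arming the eventfd with EPOLLOUT | EPOLLONESHOT, and registered fds are
// armed as oneshot so each IoRunnable fires at most once per register().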
const IoDriver = struct { | |
poll_fd: std.os.fd_t, | |
notify_fd: std.os.fd_t, | |
fn init() !IoDriver { | |
const poll_fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC); | |
errdefer std.os.close(poll_fd); | |
const notify_fd = try std.os.eventfd(0, std.os.EFD_CLOEXEC | std.os.EFD_NONBLOCK); | |
errdefer std.os.close(notify_fd); | |
var event = std.os.epoll_event{ | |
.events = std.os.EPOLLONESHOT, | |
.data = .{ .ptr = 0 }, | |
}; | |
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, notify_fd, &event); | |
return IoDriver{ | |
.poll_fd = poll_fd, | |
.notify_fd = notify_fd, | |
}; | |
} | |
fn deinit(self: *IoDriver) void { | |
std.os.close(self.poll_fd); | |
std.os.close(self.notify_fd); | |
} | |
fn notify(self: *IoDriver) void { | |
const fd = self.notify_fd; | |
var event = std.os.epoll_event{ | |
.events = std.os.EPOLLOUT | std.os.EPOLLONESHOT, | |
.data = .{ .ptr = 0 }, | |
}; | |
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_MOD, fd, &event) catch |err| switch (err) { | |
error.FileDescriptorNotRegistered => { | |
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_ADD, fd, &event) catch {}; | |
}, | |
else => {}, | |
}; | |
} | |
fn register(self: *IoDriver, fd: std.os.fd_t, io_runnable: *IoRunnable) !void { | |
var events: u32 = std.os.EPOLLONESHOT | std.os.EPOLLRDHUP | std.os.EPOLLHUP; | |
if (io_runnable.is_readable) | |
events |= std.os.EPOLLIN; | |
if (io_runnable.is_writable) | |
events |= std.os.EPOLLOUT; | |
io_runnable.fd = fd; | |
var event = std.os.epoll_event{ | |
.events = events, | |
.data = .{ .ptr = @ptrToInt(io_runnable) }, | |
}; | |
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_MOD, fd, &event) catch |err| switch (err) { | |
error.FileDescriptorNotRegistered => { | |
try std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_ADD, fd, &event); | |
}, | |
else => |e| return e, | |
}; | |
} | |
fn wait(self: *IoDriver) Batch { | |
var batch = Batch{}; | |
var events: [128]std.os.epoll_event = undefined; | |
const count = std.os.epoll_wait(self.poll_fd, &events, -1); | |
for (events[0..count]) |event| { | |
const io_runnable = @intToPtr(?*IoRunnable, event.data.ptr) orelse continue; | |
io_runnable.is_closable = (event.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0); | |
io_runnable.is_writable = (event.events & (std.os.EPOLLOUT) != 0); | |
io_runnable.is_readable = (event.events & (std.os.EPOLLIN) != 0); | |
batch.push(&io_runnable.runnable); | |
} | |
return batch; | |
} | |
}; | |
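// One Worker per spawned thread: a bounded local ring buffer, a LIFO slot for
// high-priority tasks, an unbounded overflow queue, and work stealing from
// other workers and the pool-wide run queue when local work runs out.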
const Worker = struct { | |
pool: *ThreadPool, | |
thread: *std.Thread, | |
spawned_next: ?*Worker = null, | |
shutdown_event: Event = .{}, | |
run_queue: BoundedQueue = .{}, | |
run_queue_next: ?*Runnable = null, | |
run_queue_lifo: ?*Runnable = null, | |
run_queue_overflow: UnboundedQueue = .{}, | |
tick: usize = undefined, | |
is_waking: bool = true, | |
next_target: ?*Worker = null, | |
threadlocal var current: ?*Worker = null; | |
fn spawn(pool: *ThreadPool) bool { | |
const Spawner = struct { | |
thread: *std.Thread = undefined, | |
thread_pool: *ThreadPool, | |
data_put_event: Event = .{}, | |
data_get_event: Event = .{}, | |
fn entry(self: *@This()) void { | |
self.data_put_event.wait(); | |
const thread = self.thread; | |
const thread_pool = self.thread_pool; | |
self.data_get_event.notify(); | |
Worker.run(thread, thread_pool); | |
} | |
}; | |
var spawner = Spawner{ .thread_pool = pool }; | |
spawner.thread = std.Thread.spawn(&spawner, Spawner.entry) catch return false; | |
spawner.data_put_event.notify(); | |
spawner.data_get_event.wait(); | |
return true; | |
} | |
fn run(thread: *std.Thread, pool: *ThreadPool) void { | |
var self = Worker{ | |
.thread = thread, | |
.pool = pool, | |
}; | |
self.tick = @ptrToInt(&self); | |
current = &self; | |
defer current = null; | |
var spawned_queue = @atomicLoad(?*Worker, &pool.spawned_queue, .Monotonic); | |
while (true) { | |
self.spawned_next = spawned_queue; | |
spawned_queue = @cmpxchgWeak( | |
?*Worker, | |
&pool.spawned_queue, | |
spawned_queue, | |
&self, | |
.Release, | |
.Monotonic, | |
) orelse break; | |
} | |
while (true) { | |
if (self.pop()) |runnable| { | |
if (self.is_waking) { | |
self.is_waking = false; | |
_ = pool.tryNotifyWith(true); | |
} | |
self.tick +%= 1; | |
runnable.run(); | |
continue; | |
} | |
self.is_waking = switch (pool.tryWaitWith(&self)) { | |
.resumed => false, | |
.notified => true, | |
.shutdown => { | |
self.shutdown_event.wait(); | |
break; | |
}, | |
}; | |
} | |
} | |
fn push(self: *Worker, hints: ScheduleHints, batchable: anytype) void { | |
var batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (hints.priority == .High) { | |
const new_lifo = batch.pop(); | |
if (@atomicLoad(?*Runnable, &self.run_queue_lifo, .Monotonic) == null) { | |
@atomicStore(?*Runnable, &self.run_queue_lifo, new_lifo, .Release); | |
} else if (@atomicRmw(?*Runnable, &self.run_queue_lifo, .Xchg, new_lifo, .AcqRel)) |old_lifo| { | |
batch.pushFront(old_lifo); | |
} | |
} | |
if (hints.priority == .Low) { | |
if (self.run_queue_next) |old_next| | |
batch.pushFront(old_next); | |
self.run_queue_next = null; | |
self.run_queue_next = self.pop() orelse batch.pop(); | |
} | |
if (self.run_queue.push(batch)) |overflowed| | |
self.run_queue_overflow.push(overflowed); | |
} | |
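// Pop order is rotated by `tick` so the shared, overflow, and LIFO queues are
// polled periodically and remote work is not starved by the local buffer.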
fn pop(self: *Worker) ?*Runnable { | |
if (self.tick % 127 == 0) { | |
if (self.popAndStealFromOthers()) |runnable| | |
return runnable; | |
} | |
if (self.tick % 61 == 0) { | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
} | |
if (self.tick % 31 == 0) { | |
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable| | |
return runnable; | |
} | |
if (self.tick % 13 == 0) { | |
if (self.popAndStealLifo(self)) |runnable| | |
return runnable; | |
} | |
if (self.run_queue.pop()) |runnable| | |
return runnable; | |
if (self.popAndStealLifo(self)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
if (self.popAndStealFromOthers()) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
return null; | |
} | |
fn popAndStealLifo(self: *Worker, target: *Worker) ?*Runnable { | |
var run_queue_lifo = @atomicLoad(?*Runnable, &target.run_queue_lifo, .Monotonic); | |
while (true) { | |
if (run_queue_lifo == null) | |
return null; | |
run_queue_lifo = @cmpxchgWeak( | |
?*Runnable, | |
&target.run_queue_lifo, | |
run_queue_lifo, | |
null, | |
.Acquire, | |
.Monotonic, | |
) orelse return run_queue_lifo; | |
} | |
} | |
fn popAndStealFromOthers(self: *Worker) ?*Runnable { | |
var num_workers = blk: { | |
const counter_value = @atomicLoad(u32, &self.pool.counter, .Monotonic); | |
const counter = Counter.unpack(counter_value); | |
break :blk counter.spawned; | |
}; | |
while (num_workers > 0) : (num_workers -= 1) { | |
const target = self.next_target orelse blk: { | |
break :blk @atomicLoad(?*Worker, &self.pool.spawned_queue, .Acquire) orelse { | |
std.debug.panic("Worker observed empty spawned queue when work-stealing", .{}); | |
}; | |
}; | |
self.next_target = target.spawned_next; | |
if (target == self) | |
continue; | |
if (self.run_queue.popAndStealBounded(&target.run_queue)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&target.run_queue_overflow)) |runnable| | |
return runnable; | |
if (self.popAndStealLifo(target)) |runnable| | |
return runnable; | |
} | |
return null; | |
} | |
}; | |
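// Mutex-protected batch queue used for the pool-wide run queue and per-worker
// overflow; `shared_size` is read atomically so stealers can skip empty queues
// without taking the lock.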
const UnboundedQueue = struct { | |
lock: Mutex = .{}, | |
batch: Batch = .{}, | |
shared_size: usize = 0, | |
fn push(self: *UnboundedQueue, batchable: anytype) void { | |
const batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
const held = self.lock.acquire(); | |
defer held.release(); | |
self.batch.push(batch); | |
var shared_size = self.shared_size; | |
shared_size += batch.size; | |
@atomicStore(usize, &self.shared_size, shared_size, .Release); | |
} | |
fn tryAcquireConsumer(self: *UnboundedQueue) ?Consumer { | |
var shared_size = @atomicLoad(usize, &self.shared_size, .Acquire); | |
if (shared_size == 0) | |
return null; | |
const held = self.lock.acquire(); | |
shared_size = self.shared_size; | |
if (shared_size == 0) { | |
held.release(); | |
return null; | |
} | |
return Consumer{ | |
.held = held, | |
.queue = self, | |
.size = shared_size, | |
}; | |
} | |
const Consumer = struct { | |
held: Mutex.Held, | |
queue: *UnboundedQueue, | |
size: usize, | |
fn release(self: Consumer) void { | |
@atomicStore(usize, &self.queue.shared_size, self.size, .Release); | |
self.held.release(); | |
} | |
fn pop(self: *Consumer) ?*Runnable { | |
const runnable = self.queue.batch.pop() orelse return null; | |
self.size -= 1; | |
return runnable; | |
} | |
}; | |
}; | |
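// Fixed-size (256 entry) ring buffer owned by a single worker; other workers
// may steal from the head. On overflow, half the buffer plus the rest of the
// batch is handed back to the caller to push onto the overflow queue.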
const BoundedQueue = struct { | |
head: usize = 0, | |
tail: usize = 0, | |
buffer: [256]*Runnable = undefined, | |
fn push(self: *BoundedQueue, batchable: anytype) ?Batch { | |
var batch = Batch.from(batchable); | |
while (true) : (yieldCpu()) { | |
if (batch.isEmpty()) | |
return null; | |
var tail = self.tail; | |
var head = @atomicLoad(usize, &self.head, .Acquire); | |
var size = tail -% head; | |
if (size < self.buffer.len) { | |
while (size < self.buffer.len) { | |
const runnable = batch.pop() orelse break; | |
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered); | |
tail +%= 1; | |
size += 1; | |
} | |
@atomicStore(usize, &self.tail, tail, .Release); | |
continue; | |
} | |
var migrate = self.buffer.len / 2; | |
if (@cmpxchgWeak( | |
usize, | |
&self.head, | |
head, | |
head +% migrate, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
var overflowed = Batch{}; | |
while (migrate > 0) : (migrate -= 1) { | |
const runnable = self.buffer[head % self.buffer.len]; | |
overflowed.push(runnable); | |
head +%= 1; | |
} | |
overflowed.push(batch); | |
return overflowed; | |
} | |
} | |
fn pop(self: *BoundedQueue) ?*Runnable { | |
while (true) : (yieldCpu()) { | |
const tail = self.tail; | |
const head = @atomicLoad(usize, &self.head, .Acquire); | |
const size = tail -% head; | |
if (size == 0) | |
return null; | |
if (@cmpxchgWeak( | |
usize, | |
&self.head, | |
head, | |
head +% 1, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
const runnable = self.buffer[head % self.buffer.len]; | |
return runnable; | |
} | |
} | |
fn popAndStealBounded(self: *BoundedQueue, target: *BoundedQueue) ?*Runnable { | |
if (target == self) | |
return self.pop(); | |
const tail = self.tail; | |
const head = @atomicLoad(usize, &self.head, .Acquire); | |
const size = tail -% head; | |
if (size != 0) | |
return self.pop(); | |
while (true) : (yieldThread()) { | |
const target_head = @atomicLoad(usize, &target.head, .Acquire); | |
const target_tail = @atomicLoad(usize, &target.tail, .Acquire); | |
const target_size = target_tail -% target_head; | |
var steal = target_size - (target_size / 2); | |
if (steal == 0) | |
return null; | |
if (steal > target.buffer.len / 2) | |
continue; | |
const first_runnable = @atomicLoad(*Runnable, &target.buffer[target_head % target.buffer.len], .Unordered); | |
var new_target_head = target_head +% 1; | |
var new_tail = tail; | |
steal -= 1; | |
while (steal > 0) : (steal -= 1) { | |
const runnable = @atomicLoad(*Runnable, &target.buffer[new_target_head % target.buffer.len], .Unordered); | |
new_target_head +%= 1; | |
@atomicStore(*Runnable, &self.buffer[new_tail % self.buffer.len], runnable, .Unordered); | |
new_tail +%= 1; | |
} | |
if (@cmpxchgWeak( | |
usize, | |
&target.head, | |
target_head, | |
new_target_head, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
@atomicStore(usize, &self.tail, new_tail, .Release); | |
return first_runnable; | |
} | |
} | |
fn popAndStealUnbounded(self: *BoundedQueue, target: *UnboundedQueue) ?*Runnable { | |
var consumer = target.tryAcquireConsumer() orelse return null; | |
defer consumer.release(); | |
const first_runnable = consumer.pop() orelse return null; | |
var tail = self.tail; | |
var head = @atomicLoad(usize, &self.head, .Acquire); | |
var size = tail -% head; | |
while (size < self.buffer.len) { | |
const runnable = consumer.pop() orelse break; | |
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered); | |
tail +%= 1; | |
size += 1; | |
} | |
@atomicStore(usize, &self.tail, tail, .Release); | |
return first_runnable; | |
} | |
}; | |
pub const Runnable = struct { | |
next: ?*Runnable = null, | |
runFn: fn (*Runnable) void, | |
pub fn run(self: *Runnable) void { | |
return (self.runFn)(self); | |
} | |
}; | |
pub const Batch = struct { | |
head: ?*Runnable = null, | |
tail: *Runnable = undefined, | |
size: usize = 0, | |
pub fn from(batchable: anytype) Batch { | |
return switch (@TypeOf(batchable)) { | |
Batch => batchable, | |
?*Runnable => from(batchable orelse return Batch{}), | |
*Runnable => { | |
batchable.next = null; | |
return Batch{ | |
.head = batchable, | |
.tail = batchable, | |
.size = 1, | |
}; | |
}, | |
else => |typ| @compileError(@typeName(typ) ++ | |
" cannot be converted into " ++ | |
@typeName(Batch)), | |
}; | |
} | |
pub fn isEmpty(self: Batch) bool { | |
return self.head == null; | |
} | |
pub const push = pushBack; | |
pub fn pushBack(self: *Batch, batchable: anytype) void { | |
const batch = from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (self.isEmpty()) { | |
self.* = batch; | |
} else { | |
self.tail.next = batch.head; | |
self.tail = batch.tail; | |
self.size += batch.size; | |
} | |
} | |
pub fn pushFront(self: *Batch, batchable: anytype) void { | |
const batch = from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (self.isEmpty()) { | |
self.* = batch; | |
} else { | |
batch.tail.next = self.head; | |
self.head = batch.head; | |
self.size += batch.size; | |
} | |
} | |
pub const pop = popFront; | |
pub fn popFront(self: *Batch) ?*Runnable { | |
const runnable = self.head orelse return null; | |
self.head = runnable.next; | |
self.size -= 1; | |
return runnable; | |
} | |
}; | |
const Semaphore = struct { | |
lock: Mutex = .{}, | |
permits: usize = 0, | |
waiters: ?*Waiter = null, | |
const Waiter = struct { | |
next: ?*Waiter = null, | |
tail: *Waiter = undefined, | |
event: Event = .{}, | |
permits: usize, | |
}; | |
fn init(permits: usize) Semaphore { | |
return .{ .permits = permits }; | |
} | |
fn wait(self: *Semaphore, permits: usize) void { | |
const held = self.lock.acquire(); | |
if (self.permits >= permits) { | |
self.permits -= permits; | |
held.release(); | |
return; | |
} | |
var waiter = Waiter{ .permits = permits }; | |
if (self.waiters) |head| { | |
head.tail.next = &waiter; | |
head.tail = &waiter; | |
} else { | |
self.waiters = &waiter; | |
waiter.tail = &waiter; | |
} | |
held.release(); | |
waiter.event.wait(); | |
} | |
fn post(self: *Semaphore, permits: usize) error{Overflow}!void { | |
var waiters: ?*Waiter = null; | |
{ | |
const held = self.lock.acquire(); | |
defer held.release(); | |
if (@addWithOverflow(usize, self.permits, permits, &self.permits)) | |
return error.Overflow; | |
while (self.waiters) |waiter| { | |
if (waiter.permits > self.permits) | |
break; | |
self.waiters = waiter.next; | |
if (self.waiters) |new_waiter| | |
new_waiter.tail = waiter.tail; | |
self.permits -= waiter.permits; | |
waiter.next = waiters; | |
waiters = waiter; | |
} | |
} | |
while (waiters) |waiter| { | |
waiters = waiter.next; | |
waiter.event.notify(); | |
} | |
} | |
}; | |
const Mutex = if (std.builtin.os.tag == .windows) | |
struct { | |
srwlock: ?system.PVOID = null, | |
pub fn acquire(self: *Mutex) Held { | |
AcquireSRWLockExclusive(&self.srwlock); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
ReleaseSRWLockExclusive(&self.mutex.srwlock); | |
} | |
}; | |
extern "kernel32" fn AcquireSRWLockExclusive( | |
srwlock: *?system.PVOID, | |
) callconv(system.WINAPI) void; | |
extern "kernel32" fn ReleaseSRWLockExclusive( | |
srwlock: *?system.PVOID, | |
) callconv(system.WINAPI) void; | |
} | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
lock: u32 = 0, | |
pub fn acquire(self: *Mutex) Held { | |
os_unfair_lock_lock(&self.lock); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
os_unfair_lock_unlock(&self.mutex.lock); | |
} | |
}; | |
extern "c" fn os_unfair_lock_lock( | |
unfair_lock: *u32, | |
) callconv(.C) void; | |
extern "c" fn os_unfair_lock_unlock( | |
unfair_lock: *u32, | |
) callconv(.C) void; | |
} | |
else if (std.builtin.os.tag == .linux) | |
struct { | |
state: i32 = UNLOCKED, | |
const UNLOCKED: i32 = 0; | |
const LOCKED: i32 = 1; | |
const WAITING: i32 = 2; | |
pub fn acquire(self: *Mutex) Held { | |
const state = @atomicRmw(i32, &self.state, .Xchg, LOCKED, .Acquire); | |
if (state != UNLOCKED) | |
self.acquireSlow(state); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
switch (@atomicRmw(i32, &self.mutex.state, .Xchg, UNLOCKED, .Release)) { | |
UNLOCKED => unreachable, // unlocked an unlocked mutex | |
LOCKED => {}, | |
WAITING => self.mutex.releaseSlow(), | |
else => unreachable, | |
} | |
} | |
}; | |
fn acquireSlow(self: *Mutex, current_state: i32) void { | |
@setCold(true); | |
var wait_state = current_state; | |
while (true) { | |
var spin: u8 = 0; | |
while (spin < 5) : (spin += 1) { | |
switch (@atomicLoad(i32, &self.state, .Monotonic)) { | |
UNLOCKED => _ = @cmpxchgWeak( | |
i32, | |
&self.state, | |
UNLOCKED, | |
wait_state, | |
.Acquire, | |
.Monotonic, | |
) orelse return, | |
LOCKED => {}, | |
WAITING => break, | |
else => unreachable, | |
} | |
if (spin < 4) { | |
var pause: u8 = 30; | |
while (pause > 0) : (pause -= 1) | |
yieldCpu(); | |
} else { | |
yieldThread(); | |
} | |
} | |
const state = @atomicRmw(i32, &self.state, .Xchg, WAITING, .Acquire); | |
if (state == UNLOCKED) | |
return; | |
wait_state = WAITING; | |
switch (system.getErrno(system.futex_wait( | |
&self.state, | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT, | |
WAITING, | |
null, | |
))) { | |
0 => {}, | |
system.EINTR => {}, | |
system.EAGAIN => {}, | |
else => unreachable, | |
} | |
} | |
} | |
fn releaseSlow(self: *Mutex) void { | |
@setCold(true); | |
while (true) { | |
return switch (system.getErrno(system.futex_wake( | |
&self.state, | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE, | |
@as(i32, 1), | |
))) { | |
0 => {}, | |
system.EINTR => continue, | |
system.EFAULT => {}, | |
else => unreachable, | |
}; | |
} | |
} | |
} | |
else | |
struct { | |
locked: bool = false, | |
pub fn acquire(self: *Mutex) Held { | |
while (@atomicRmw(bool, &self.locked, .Xchg, true, .Acquire)) | |
yieldThread(); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
@atomicStore(bool, &self.mutex.locked, false, .Release); | |
} | |
}; | |
}; | |
const Event = if (std.builtin.os.tag == .windows) | |
struct { | |
key: u32 = undefined, | |
pub fn wait(self: *Event) void { | |
const status = NtWaitForKeyedEvent(null, &self.key, system.FALSE, null); | |
std.debug.assert(status == .SUCCESS); | |
} | |
pub fn notify(self: *Event) void { | |
const status = NtReleaseKeyedEvent(null, &self.key, system.FALSE, null); | |
std.debug.assert(status == .SUCCESS); | |
} | |
extern "NtDll" fn NtWaitForKeyedEvent( | |
handle: ?system.HANDLE, | |
key: ?*const u32, | |
alertable: system.BOOLEAN, | |
timeout: ?*const system.LARGE_INTEGER, | |
) callconv(system.WINAPI) system.NTSTATUS; | |
extern "NtDll" fn NtReleaseKeyedEvent( | |
handle: ?system.HANDLE, | |
key: ?*const u32, | |
alertable: system.BOOLEAN, | |
timeout: ?*const system.LARGE_INTEGER, | |
) callconv(system.WINAPI) system.NTSTATUS; | |
} | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
state: enum(u32) { | |
pending = 0, | |
notified, | |
} = .pending, | |
pub fn wait(self: *Event) void { | |
while (true) { | |
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) { | |
.pending => {}, | |
.notified => { | |
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic); | |
return; | |
}, | |
} | |
const status = __ulock_wait( | |
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO, | |
@ptrCast(?*const c_void, &self.state), | |
@enumToInt(@TypeOf(self.state).pending), | |
~@as(u32, 0), | |
); | |
if (status < 0) { | |
switch (-status) { | |
system.EINTR => {}, | |
else => unreachable, | |
} | |
} | |
} | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release); | |
while (true) { | |
const status = __ulock_wake( | |
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO, | |
@ptrCast(?*const c_void, &self.state), | |
@as(u32, 0), | |
); | |
if (status < 0) { | |
switch (-status) { | |
system.ENOENT => {}, | |
system.EINTR => continue, | |
else => unreachable, | |
} | |
} | |
return; | |
} | |
} | |
const ULF_NO_ERRNO = 0x1000000; | |
const UL_COMPARE_AND_WAIT = 0x1; | |
extern "c" fn __ulock_wait( | |
operation: u32, | |
address: ?*const c_void, | |
value: u64, | |
timeout_us: u32, | |
) callconv(.C) c_int; | |
extern "c" fn __ulock_wake( | |
operation: u32, | |
address: ?*const c_void, | |
value: u64, | |
) callconv(.C) c_int; | |
} | |
else if (std.builtin.os.tag == .linux) | |
struct { | |
state: enum(i32) { | |
pending, | |
notified, | |
} = .pending, | |
pub fn wait(self: *Event) void { | |
while (true) { | |
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) { | |
.pending => {}, | |
.notified => { | |
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic); | |
return; | |
}, | |
} | |
switch (system.getErrno(system.futex_wait( | |
@ptrCast(*const i32, &self.state), | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT, | |
@enumToInt(@TypeOf(self.state).pending), | |
null, | |
))) { | |
0 => {}, | |
system.EINTR => {}, | |
system.EAGAIN => {}, | |
else => unreachable, | |
} | |
} | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release); | |
while (true) { | |
return switch (system.getErrno(system.futex_wake( | |
@ptrCast(*const i32, &self.state), | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE, | |
@as(i32, 1), | |
))) { | |
0 => {}, | |
system.EINTR => continue, | |
system.EFAULT => {}, | |
else => unreachable, | |
}; | |
} | |
} | |
} | |
else | |
struct { | |
notified: bool = false, | |
pub fn wait(self: *Event) void { | |
while (!@atomicLoad(bool, &self.notified, .Acquire)) | |
yieldThread(); | |
@atomicStore(bool, &self.notified, false, .Monotonic); | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(bool, &self.notified, true, .Release); | |
} | |
}; | |
const yieldThread = if (std.builtin.os.tag == .windows) | |
struct { | |
fn yield() void { | |
system.kernel32.Sleep(0); | |
} | |
}.yield | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
fn yield() void { | |
_ = thread_switch(MACH_PORT_NULL, SWITCH_OPTION_DEPRESS, 1); | |
} | |
const MACH_PORT_NULL = 0; | |
const SWITCH_OPTION_DEPRESS = 1; | |
// https://www.gnu.org/software/hurd/gnumach-doc/Hand_002dOff-Scheduling.html | |
extern "c" fn thread_switch( | |
thread: usize, | |
options: c_int, | |
timeout_ms: c_int, | |
) callconv(.C) c_int; | |
}.yield | |
else if (std.builtin.os.tag == .linux or std.builtin.link_libc) | |
struct { | |
fn yield() void { | |
_ = system.sched_yield(); | |
} | |
}.yield | |
else | |
yieldCpu; | |
fn yieldCpu() void { | |
switch (std.builtin.arch) { | |
.i386, .x86_64 => asm volatile ("pause"), | |
.arm, .aarch64 => asm volatile ("yield"), | |
else => {}, | |
} | |
} |
const std = @import("std"); | |
const linux = std.os.linux; | |
pub fn main() !void { | |
var ring: Ring = undefined; | |
try ring.init(); | |
defer ring.deinit(); | |
var server: Server = undefined; | |
try server.init(12345); | |
defer server.deinit(); | |
var frame = async server.run(&ring); | |
while (true) { | |
try ring.poll(); | |
} | |
} | |
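// Wrapper around linux.IO_Uring: each Completion carries an anyframe, so async
// functions suspend after queueing an SQE and are resumed from poll() once
// their CQE (or a free SQE slot) becomes available.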
const Ring = struct { | |
inner: linux.IO_Uring, | |
queue: std.TailQueue(void), | |
fn init(self: *Ring) !void { | |
self.inner = try linux.IO_Uring.init(512, 0); | |
self.queue = .{}; | |
} | |
fn deinit(self: *Ring) void { | |
self.inner.deinit(); | |
} | |
const Completion = struct { | |
onComplete: anyframe, | |
ring_queue: std.TailQueue(void).Node = .{ .data = {} }, | |
result: i32 = undefined, | |
}; | |
fn flushCompletions(self: *Ring) void { | |
var chunk: [256]linux.io_uring_cqe = undefined; | |
while (true) { | |
const found = self.inner.copy_cqes(&chunk, 0) catch unreachable; | |
if (found == 0) | |
break; | |
for (chunk[0..found]) |cqe| { | |
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data)); | |
completion.result = cqe.res; | |
self.queue.append(&completion.ring_queue); | |
} | |
} | |
} | |
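// If the submission queue is full, park the current frame on the ring's queue
// and suspend; poll() resumes it on the next pass so get_sqe() can be retried.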
fn getSubmission(self: *Ring) *linux.io_uring_sqe { | |
while (true) { | |
return self.inner.get_sqe() catch { | |
var completion = Ring.Completion{ | |
.onComplete = @frame(), | |
}; | |
self.queue.append(&completion.ring_queue); | |
suspend; | |
continue; | |
}; | |
} | |
} | |
fn poll(self: *Ring) !void { | |
while (self.queue.popFirst()) |node| { | |
const completion = @fieldParentPtr(Completion, "ring_queue", node); | |
resume completion.onComplete; | |
} | |
_ = try self.inner.submit_and_wait(1); | |
self.flushCompletions(); | |
} | |
}; | |
const Server = struct { | |
fd: std.os.socket_t, | |
gpa: std.heap.GeneralPurposeAllocator(.{}), | |
clients: std.TailQueue(void), | |
pub fn init(self: *Server, comptime port: u16) !void { | |
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
errdefer std.os.close(self.fd); | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(self.fd, 128); | |
self.gpa = .{}; | |
self.clients = .{}; | |
std.debug.warn("Listening on :{}\n", .{port}); | |
} | |
pub fn deinit(self: *Server) void { | |
// TODO: cancel outstanding accept call? | |
// TODO: kill current clients? | |
std.os.close(self.fd); | |
std.debug.warn("Server closed\n", .{}); | |
} | |
pub fn run(self: *Server, ring: *Ring) !void { | |
errdefer self.deinit(); | |
while (true) { | |
const result = self.accept(ring); | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => continue, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const client_fd: std.os.socket_t = result; | |
const client = Client.init(self, client_fd) catch |err| { | |
std.os.close(client_fd); | |
std.debug.warn("Failed to start client: {}\n", .{client_fd}); | |
continue; | |
}; | |
client.frame = async client.run(ring); | |
self.clients.append(&client.server_clients); | |
} | |
} | |
fn accept(self: *Server, ring: *Ring) i32 { | |
const sqe = ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .ACCEPT; | |
sqe.fd = self.fd; | |
sqe.rw_flags = std.os.SOCK_CLOEXEC; | |
var completion = Ring.Completion{ | |
.onComplete = @frame(), | |
}; | |
sqe.user_data = @ptrToInt(&completion); | |
suspend; | |
return completion.result; | |
} | |
}; | |
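// One heap-allocated Client per accepted connection; its async run() frame
// drives a read/parse/respond loop until the peer closes or an error occurs.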
const Client = struct { | |
server: *Server, | |
server_clients: std.TailQueue(void).Node, | |
fd: std.os.socket_t, | |
reader: Reader, | |
writer: Writer, | |
is_closed: bool, | |
frame: @Frame(run), | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
pub fn init(server: *Server, fd: std.os.socket_t) !*Client { | |
const allocator = &server.gpa.allocator; | |
const self = try allocator.create(Client); | |
errdefer allocator.destroy(self); | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
self.* = .{ | |
.server = server, | |
.server_clients = .{ .data = {} }, | |
.fd = fd, | |
.reader = .{ .fd = fd }, | |
.writer = .{ .fd = fd }, | |
.is_closed = false, | |
.frame = undefined, | |
}; | |
return self; | |
} | |
pub fn deinit(self: *Client) void { | |
std.os.close(self.fd); | |
self.server.clients.remove(&self.server_clients); | |
const allocator = &self.server.gpa.allocator; | |
suspend { | |
allocator.destroy(self); | |
} | |
} | |
pub fn run(self: *Client, ring: *Ring) Reader.RunError!void { | |
Reader.run(self, ring) catch |err| switch (err) { | |
error.ConnectionResetByPeer => {}, | |
else => return err, | |
}; | |
} | |
const Reader = struct { | |
state: enum { read, stop } = .stop, | |
buffer: Buffer = Buffer.init(), | |
fd: i32, | |
const Buffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 }); | |
const ReaderError = std.os.RecvFromError || error{ | |
closed, | |
Eof, | |
HttpRequestTooLarge, | |
}; | |
const RunError = ReaderError; | |
fn run(client: *Client, ring: *Ring) RunError!void { | |
const self = &client.reader; | |
// const client = @fieldParentPtr(Client, "reader", self); | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
if (client.writer.state == .idle) | |
client.writer.state = .stop; | |
if (client.writer.state == .stop) | |
client.deinit(); | |
} | |
while (true) { | |
const result = self.read(ring); | |
if (client.is_closed) | |
return error.closed; | |
switch (if (result < 0) @intCast(u12, -result) else @as(u12, 0)) { | |
0 => {}, | |
std.os.EINTR => continue, | |
std.os.ENOMEM => return error.SystemResources, | |
std.os.ECONNRESET => return error.ConnectionResetByPeer, | |
else => |err| return std.os.unexpectedErrno(err), | |
} | |
const bytes = @intCast(usize, result); | |
if (bytes == 0) | |
return error.Eof; | |
self.buffer.update(bytes); | |
while (true) { | |
self.buffer.realign(); | |
if (std.mem.indexOf(u8, self.buffer.readableSlice(0), HTTP_CLRF)) |parsed| { | |
self.buffer.discard(bytes); | |
try Writer.write(client, ring, HTTP_RESPONSE); | |
continue; | |
} | |
if (self.buffer.writableLength() == 0) | |
return error.HttpRequestTooLarge; | |
break; | |
} | |
} | |
} | |
fn read(self: *Reader, ring: *Ring) i32 { | |
const sqe = ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .RECV; | |
sqe.fd = self.fd; | |
// sqe.fd = @fieldParentPtr(Client, "reader", self).fd; | |
const slice = self.buffer.writableSlice(0); | |
sqe.addr = @ptrToInt(slice.ptr); | |
sqe.len = @truncate(u32, slice.len); | |
self.state = .read; | |
var completion = Ring.Completion{ | |
.onComplete = @frame(), | |
}; | |
sqe.user_data = @ptrToInt(&completion); | |
suspend; | |
return completion.result; | |
} | |
}; | |
const Writer = struct { | |
state: enum { write, idle, stop } = .idle, | |
buffer: Buffer = Buffer.init(), | |
fd: i32, | |
const Buffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 }); | |
fn flush(self: *Writer, ring: *Ring) !void { | |
switch (self.state) { | |
.idle => { | |
self.state = .write; | |
while (true) { | |
const result = self.rawWrite(ring); | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => continue, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes_written = @intCast(usize, result); | |
self.buffer.discard(bytes_written); | |
break; | |
} | |
self.state = .idle; | |
}, | |
.write => { | |
@panic("TODO: wait for existing write to finish"); | |
}, | |
.stop => return error.closed, | |
} | |
} | |
fn write(client: *Client, ring: *Ring, src: []const u8) !void { | |
const self = &client.writer; | |
// const client = @fieldParentPtr(Client, "writer", self); | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
if (client.reader.state == .stop) | |
client.deinit(); | |
} | |
var src_left = src; | |
while (src_left.len > 0) { | |
const writable_slice = self.buffer.writableSlice(0); | |
if (writable_slice.len == 0) { | |
try self.flush(ring); | |
continue; | |
} | |
const n = std.math.min(writable_slice.len, src_left.len); | |
std.mem.copy(u8, writable_slice, src_left[0..n]); | |
self.buffer.update(n); | |
src_left = src_left[n..]; | |
} | |
while (self.buffer.readableLength() > 0) { | |
try self.flush(ring); | |
} | |
} | |
fn rawWrite(self: *Writer, ring: *Ring) i32 { | |
const sqe = ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .SEND; | |
sqe.fd = self.fd; | |
// sqe.fd = @fieldParentPtr(Client, "writer", self).fd; | |
const slice = self.buffer.readableSlice(0); | |
sqe.addr = @ptrToInt(slice.ptr); | |
sqe.len = @truncate(u32, slice.len); | |
var completion = Ring.Completion{ | |
.onComplete = @frame(), | |
}; | |
sqe.user_data = @ptrToInt(&completion); | |
suspend; | |
return completion.result; | |
} | |
}; | |
}; |
const std = @import("std"); | |
const linux = std.os.linux; | |
pub fn main() !void { | |
var ring: Ring = undefined; | |
try ring.init(); | |
defer ring.deinit(); | |
var server: Server = undefined; | |
try server.init(&ring, 12345); | |
defer server.deinit(); | |
while (true) { | |
try ring.poll(); | |
} | |
} | |
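// Same io_uring wrapper as the previous file, but completions hold a plain
// callback (`onComplete` function pointer) instead of a suspended async frame.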
const Ring = struct { | |
inner: linux.IO_Uring, | |
head: ?*Completion, | |
tail: ?*Completion, | |
fn init(self: *Ring) !void { | |
self.inner = try linux.IO_Uring.init(512, 0); | |
self.head = null; | |
self.tail = null; | |
} | |
fn deinit(self: *Ring) void { | |
self.inner.deinit(); | |
} | |
const Completion = struct { | |
result: i32 = undefined, | |
next: ?*Completion = null, | |
onComplete: fn (*Ring, *Completion) void, | |
}; | |
fn flushCompletions(self: *Ring) !void { | |
var chunk: [256]linux.io_uring_cqe = undefined; | |
while (true) { | |
const found = try self.inner.copy_cqes(&chunk, 0); | |
if (found == 0) | |
break; | |
for (chunk[0..found]) |cqe| { | |
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data)); | |
completion.result = cqe.res; | |
if (self.head == null) | |
self.head = completion; | |
if (self.tail) |tail| | |
tail.next = completion; | |
completion.next = null; | |
self.tail = completion; | |
} | |
} | |
} | |
fn getSubmission(self: *Ring) !*linux.io_uring_sqe { | |
while (true) { | |
return self.inner.get_sqe() catch { | |
try self.flushCompletions(); | |
_ = try self.inner.submit(); | |
continue; | |
}; | |
} | |
} | |
fn poll(self: *Ring) !void { | |
while (self.head) |completion| { | |
self.head = completion.next; | |
if (self.head == null) | |
self.tail = null; | |
(completion.onComplete)(self, completion); | |
} | |
_ = try self.inner.submit_and_wait(1); | |
try self.flushCompletions(); | |
} | |
}; | |
const Server = struct { | |
fd: std.os.socket_t, | |
completion: Ring.Completion, | |
gpa: std.heap.GeneralPurposeAllocator(.{}), | |
fn init(self: *Server, ring: *Ring, comptime port: u16) !void { | |
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
errdefer std.os.close(self.fd); | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(self.fd, 128); | |
self.gpa = .{}; | |
self.completion = .{ .onComplete = Server.onCompletion }; | |
try self.submitAccept(ring); | |
std.debug.warn("Listening on :{}\n", .{port}); | |
} | |
fn deinit(self: *Server) void { | |
std.os.close(self.fd); | |
std.debug.warn("Server closed\n", .{}); | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Server, "completion", completion); | |
self.process(ring, completion.result) catch self.deinit(); | |
} | |
fn process(self: *Server, ring: *Ring, result: i32) !void { | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitAccept(ring); | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const client_fd: std.os.socket_t = result; | |
Client.init(self, ring, client_fd) catch |err| { | |
std.os.close(client_fd); | |
std.debug.warn("Failed to start client: {}\n", .{client_fd}); | |
}; | |
try self.submitAccept(ring); | |
} | |
fn submitAccept(self: *Server, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .ACCEPT; | |
sqe.fd = self.fd; | |
sqe.rw_flags = std.os.SOCK_CLOEXEC; | |
sqe.user_data = @ptrToInt(&self.completion); | |
} | |
}; | |
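// Callback-driven client: Reader and Writer each own a Completion and re-submit
// their own SQEs from process(); the Client is destroyed once both sides have
// reached the .stop state.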
const Client = struct { | |
server: *Server, | |
fd: std.os.socket_t, | |
reader: Reader, | |
writer: Writer, | |
is_closed: bool, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn init(server: *Server, ring: *Ring, fd: std.os.socket_t) !void { | |
const allocator = &server.gpa.allocator; | |
const self = try allocator.create(Client); | |
errdefer allocator.destroy(self); | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
self.fd = fd; | |
self.server = server; | |
self.is_closed = false; | |
try self.reader.init(ring); | |
try self.writer.init(ring); | |
} | |
fn deinit(self: *Client) void { | |
std.os.close(self.fd); | |
const allocator = &self.server.gpa.allocator; | |
allocator.destroy(self); | |
} | |
const Reader = struct { | |
state: enum { read, stop } = .stop, | |
completion: Ring.Completion, | |
buffer: Buffer = Buffer.init(), | |
const Buffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 }); | |
fn init(self: *Reader, ring: *Ring) !void { | |
const client = @fieldParentPtr(Client, "reader", self); | |
self.* = .{ .completion = .{ .onComplete = onCompletion } }; | |
try self.submitRead(ring); | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Reader, "completion", completion); | |
const client = @fieldParentPtr(Client, "reader", self); | |
self.process(client, ring, completion.result) catch |err| { | |
if (self.state == .stop and client.writer.state == .stop) | |
client.deinit(); | |
}; | |
} | |
fn process(self: *Reader, client: *Client, ring: *Ring, result: i32) !void { | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
if (client.writer.state == .idle) | |
client.writer.state = .stop; | |
} | |
if (client.is_closed) | |
return error.closed; | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitRead(ring); | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes = @intCast(usize, result); | |
if (bytes == 0) | |
return error.Eof; | |
self.buffer.update(bytes); | |
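// Scan for the blank line ("\r\n\r\n") that ends each request; pipelined requests are handled in one pass, each consuming its bytes and queueing one canned response.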
while (true) { | |
self.buffer.realign(); | |
if (std.mem.indexOf(u8, self.buffer.readableSlice(0), HTTP_CLRF)) |parsed| { | |
self.buffer.discard(parsed + HTTP_CLRF.len); // consume exactly one parsed request | |
client.writer.buffer.writeAssumeCapacity(HTTP_RESPONSE); | |
if (client.writer.state == .idle) { | |
try client.writer.submitWrite(ring); | |
} | |
continue; | |
} | |
if (self.buffer.writableLength() == 0) | |
return error.HttpRequestTooLarge; | |
try self.submitRead(ring); | |
return; | |
} | |
} | |
fn submitRead(self: *Reader, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .RECV; | |
sqe.fd = @fieldParentPtr(Client, "reader", self).fd; | |
const slice = self.buffer.writableSlice(0); | |
sqe.addr = @ptrToInt(slice.ptr); | |
sqe.len = @truncate(u32, slice.len); | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .read; | |
} | |
}; | |
const Writer = struct { | |
state: enum { write, idle, stop } = .idle, | |
completion: Ring.Completion, | |
buffer: Buffer = Buffer.init(), | |
const Buffer = std.fifo.LinearFifo(u8, .{ .Static = NUM_RESPONSE_CHUNKS * HTTP_RESPONSE.len }); | |
const NUM_RESPONSE_CHUNKS = 128; | |
fn init(self: *Writer, ring: *Ring) !void { | |
self.* = .{ .completion = .{ .onComplete = onCompletion } }; | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Writer, "completion", completion); | |
const client = @fieldParentPtr(Client, "writer", self); | |
self.process(client, ring, completion.result) catch |err| { | |
if (self.state == .stop and client.reader.state == .stop) | |
client.deinit(); | |
}; | |
} | |
fn process(self: *Writer, client: *Client, ring: *Ring, result: i32) !void { | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
} | |
if (client.is_closed) | |
return error.closed; | |
if (self.state == .idle) unreachable; | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitWrite(ring); | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes_written = @intCast(usize, result); | |
self.buffer.discard(bytes_written); | |
if (self.buffer.readableLength() == 0) { | |
self.state = .idle; | |
return; | |
} | |
try self.submitWrite(ring); | |
} | |
fn submitWrite(self: *Writer, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .SEND; | |
sqe.fd = @fieldParentPtr(Client, "writer", self).fd; | |
const slice = self.buffer.readableSlice(0); | |
sqe.addr = @ptrToInt(slice.ptr); | |
sqe.len = @truncate(u32, slice.len); | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .write; | |
} | |
}; | |
}; |
// ---- next gist file: io_uring variant using RECV/SEND with a fixed per-client buffer ----
const std = @import("std"); | |
const linux = std.os.linux; | |
pub fn main() !void { | |
var ring: Ring = undefined; | |
try ring.init(); | |
defer ring.deinit(); | |
var server: Server = undefined; | |
try server.init(&ring, 12345); | |
defer server.deinit(); | |
while (true) { | |
try ring.poll(); | |
} | |
} | |
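// Ring wraps linux.IO_Uring and links finished operations into an intrusive FIFO of Completion nodes whose callbacks are run from poll().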
const Ring = struct { | |
inner: linux.IO_Uring, | |
head: ?*Completion, | |
tail: ?*Completion, | |
fn init(self: *Ring) !void { | |
self.inner = try linux.IO_Uring.init(512, 0); | |
self.head = null; | |
self.tail = null; | |
} | |
fn deinit(self: *Ring) void { | |
self.inner.deinit(); | |
} | |
const Completion = struct { | |
result: i32 = undefined, | |
next: ?*Completion = null, | |
onComplete: fn (*Ring, *Completion) void, | |
}; | |
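// Copy any available CQEs out of the kernel ring and append their Completion nodes to the head/tail list; callbacks run later in poll().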
fn flushCompletions(self: *Ring) !void { | |
var chunk: [256]linux.io_uring_cqe = undefined; | |
while (true) { | |
const found = try self.inner.copy_cqes(&chunk, 0); | |
if (found == 0) | |
break; | |
for (chunk[0..found]) |cqe| { | |
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data)); | |
completion.result = cqe.res; | |
if (self.head == null) | |
self.head = completion; | |
if (self.tail) |tail| | |
tail.next = completion; | |
completion.next = null; | |
self.tail = completion; | |
} | |
} | |
} | |
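// When the submission queue is full, flush completions and submit pending SQEs, then retry until get_sqe() succeeds.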
fn getSubmission(self: *Ring) !*linux.io_uring_sqe { | |
while (true) { | |
return self.inner.get_sqe() catch { | |
try self.flushCompletions(); | |
_ = try self.inner.submit(); | |
continue; | |
}; | |
} | |
} | |
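// Invoke the callbacks collected by flushCompletions(), then block until at least one new completion is available.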
fn poll(self: *Ring) !void { | |
while (self.head) |completion| { | |
self.head = completion.next; | |
if (self.head == null) | |
self.tail = null; | |
(completion.onComplete)(self, completion); | |
} | |
_ = try self.inner.submit_and_wait(1); | |
try self.flushCompletions(); | |
} | |
}; | |
const Server = struct { | |
fd: std.os.socket_t, | |
completion: Ring.Completion, | |
gpa: std.heap.GeneralPurposeAllocator(.{}), | |
fn init(self: *Server, ring: *Ring, comptime port: u16) !void { | |
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
errdefer std.os.close(self.fd); | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(self.fd, 128); | |
self.gpa = .{}; | |
self.completion = .{ .onComplete = Server.onCompletion }; | |
try self.submitAccept(ring); | |
std.debug.warn("Listening on :{}\n", .{port}); | |
} | |
fn deinit(self: *Server) void { | |
std.os.close(self.fd); | |
std.debug.warn("Server closed\n", .{}); | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Server, "completion", completion); | |
self.process(ring, completion.result) catch self.deinit(); | |
} | |
fn process(self: *Server, ring: *Ring, result: i32) !void { | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitAccept(ring); | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const client_fd: std.os.socket_t = result; | |
Client.init(self, ring, client_fd) catch |err| { | |
std.os.close(client_fd); | |
std.debug.warn("Failed to start client: {}\n", .{client_fd}); | |
}; | |
try self.submitAccept(ring); | |
} | |
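// Queue an ACCEPT on the listening socket; the accept4()-style flags ride in sqe.rw_flags.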
fn submitAccept(self: *Server, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .ACCEPT; | |
sqe.fd = self.fd; | |
sqe.rw_flags = std.os.SOCK_CLOEXEC; | |
sqe.user_data = @ptrToInt(&self.completion); | |
} | |
}; | |
const Client = struct { | |
server: *Server, | |
fd: std.os.socket_t, | |
reader: Reader, | |
writer: Writer, | |
is_closed: bool, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn init(server: *Server, ring: *Ring, fd: std.os.socket_t) !void { | |
const allocator = &server.gpa.allocator; | |
const self = try allocator.create(Client); | |
errdefer allocator.destroy(self); | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
self.fd = fd; | |
self.server = server; | |
self.is_closed = false; | |
try self.reader.init(ring); | |
try self.writer.init(ring); | |
} | |
fn deinit(self: *Client) void { | |
std.os.close(self.fd); | |
const allocator = &self.server.gpa.allocator; | |
allocator.destroy(self); | |
} | |
const Reader = struct { | |
state: enum { read, stop } = .stop, | |
completion: Ring.Completion, | |
recv_bytes: usize = 0, | |
recv_buffer: [4096]u8 = undefined, | |
iovec: std.os.iovec = undefined, | |
fn init(self: *Reader, ring: *Ring) !void { | |
const client = @fieldParentPtr(Client, "reader", self); | |
self.* = .{ .completion = .{ .onComplete = onCompletion } }; | |
self.iovec.iov_base = &self.recv_buffer; | |
self.iovec.iov_len = self.recv_buffer.len; | |
try self.submitRead(ring); | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Reader, "completion", completion); | |
const client = @fieldParentPtr(Client, "reader", self); | |
self.process(client, ring, completion.result) catch |err| { | |
if (self.state == .stop and client.writer.state == .stop) | |
client.deinit(); | |
}; | |
} | |
fn process(self: *Reader, client: *Client, ring: *Ring, result: i32) !void { | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
if (client.writer.state == .idle) | |
client.writer.state = .stop; | |
} | |
if (client.is_closed) | |
return error.closed; | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitRead(ring); | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes = @intCast(usize, result); | |
self.recv_bytes += bytes; | |
if (bytes == 0) | |
return error.Eof; | |
while (true) { | |
const request_buffer = self.recv_buffer[0..self.recv_bytes]; | |
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| { | |
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len)..request_buffer.len]; | |
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer); | |
self.recv_bytes = unparsed_buffer.len; | |
client.writer.send_bytes += HTTP_RESPONSE.len; | |
if (client.writer.state == .idle) { | |
client.writer.state = .write; | |
try client.writer.process(client, ring, 0); | |
} | |
continue; | |
} | |
const readable_buffer = self.recv_buffer[self.recv_bytes..]; | |
if (readable_buffer.len == 0) | |
return error.HttpRequestTooLarge; | |
self.iovec.iov_base = readable_buffer.ptr; | |
self.iovec.iov_len = readable_buffer.len; | |
try self.submitRead(ring); | |
return; | |
} | |
} | |
fn submitRead(self: *Reader, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
// sqe.opcode = .READV; | |
sqe.opcode = .RECV; | |
sqe.fd = @fieldParentPtr(Client, "reader", self).fd; | |
// sqe.addr = @ptrToInt(&self.iovec); | |
// sqe.len = 1; | |
sqe.addr = @ptrToInt(self.iovec.iov_base); | |
sqe.len = @truncate(u32, self.iovec.iov_len); | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .read; | |
} | |
}; | |
const Writer = struct { | |
state: enum { write, idle, stop } = .idle, | |
completion: Ring.Completion, | |
send_bytes: usize = 0, | |
send_partial: usize = 0, | |
iovec: std.os.iovec_const = undefined, | |
fn init(self: *Writer, ring: *Ring) !void { | |
self.* = .{ .completion = .{ .onComplete = onCompletion } }; | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Writer, "completion", completion); | |
const client = @fieldParentPtr(Client, "writer", self); | |
self.process(client, ring, completion.result) catch |err| { | |
if (self.state == .stop and client.reader.state == .stop) | |
client.deinit(); | |
}; | |
} | |
fn process(self: *Writer, client: *Client, ring: *Ring, result: i32) !void { | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
} | |
if (client.is_closed) | |
return error.closed; | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitWrite(ring); | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes_written = @intCast(usize, result); | |
self.send_bytes -= bytes_written; | |
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len; // offset reached inside the current response after a short write | |
if (self.send_bytes == 0) { | |
self.state = .idle; | |
return; | |
} | |
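// Send from a comptime-repeated block of HTTP_RESPONSE so several pipelined responses go out in one SEND; send_partial tracks where a short write stopped within the current response.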
const NUM_RESPONSE_CHUNKS = 128; | |
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS; | |
self.iovec.iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial; | |
self.iovec.iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial); | |
try self.submitWrite(ring); | |
} | |
fn submitWrite(self: *Writer, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
// sqe.opcode = .WRITEV; | |
sqe.opcode = .SEND; | |
sqe.fd = @fieldParentPtr(Client, "writer", self).fd; | |
// sqe.addr = @ptrToInt(&self.iovec); | |
// sqe.len = 1; | |
sqe.addr = @ptrToInt(self.iovec.iov_base); | |
sqe.len = @truncate(u32, self.iovec.iov_len); | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .write; | |
} | |
}; | |
}; |
// ---- next gist file: io_uring variant using POLL_ADD readiness plus non-blocking READV/WRITEV ----
const std = @import("std"); | |
const linux = std.os.linux; | |
pub fn main() !void { | |
var ring: Ring = undefined; | |
try ring.init(); | |
defer ring.deinit(); | |
var server: Server = undefined; | |
try server.init(&ring, 12345); | |
defer server.deinit(); | |
while (true) { | |
try ring.poll(); | |
} | |
} | |
const Ring = struct { | |
inner: linux.IO_Uring, | |
head: ?*Completion, | |
tail: ?*Completion, | |
fn init(self: *Ring) !void { | |
self.inner = try linux.IO_Uring.init(512, 0); | |
self.head = null; | |
self.tail = null; | |
} | |
fn deinit(self: *Ring) void { | |
self.inner.deinit(); | |
} | |
const Completion = struct { | |
result: i32 = undefined, | |
next: ?*Completion = null, | |
onComplete: fn(*Ring, *Completion) void, | |
}; | |
fn flushCompletions(self: *Ring) !void { | |
var chunk: [256]linux.io_uring_cqe = undefined; | |
while (true) { | |
const found = try self.inner.copy_cqes(&chunk, 0); | |
if (found == 0) | |
break; | |
for (chunk[0..found]) |cqe| { | |
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data)); | |
completion.result = cqe.res; | |
if (self.head == null) | |
self.head = completion; | |
if (self.tail) |tail| | |
tail.next = completion; | |
completion.next = null; | |
self.tail = completion; | |
} | |
} | |
} | |
fn getSubmission(self: *Ring) !*linux.io_uring_sqe { | |
while (true) { | |
return self.inner.get_sqe() catch { | |
try self.flushCompletions(); | |
_ = try self.inner.submit(); | |
continue; | |
}; | |
} | |
} | |
fn poll(self: *Ring) !void { | |
while (self.head) |completion| { | |
self.head = completion.next; | |
if (self.head == null) | |
self.tail = null; | |
(completion.onComplete)(self, completion); | |
} | |
_ = try self.inner.submit_and_wait(1); | |
try self.flushCompletions(); | |
} | |
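// This variant first arms POLL_ADD for readiness and only then performs the I/O, re-arming the poll whenever the non-blocking socket reports EAGAIN.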
fn submitPoll(self: *Ring, fd: std.os.fd_t, flags: u32, completion: *Completion) !void { | |
const sqe = try self.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .POLL_ADD; | |
sqe.fd = fd; | |
sqe.rw_flags = flags | linux.POLLERR | linux.POLLHUP; | |
sqe.user_data = @ptrToInt(completion); | |
} | |
}; | |
const Server = struct { | |
fd: std.os.socket_t, | |
completion: Ring.Completion, | |
gpa: std.heap.GeneralPurposeAllocator(.{}), | |
state: enum { poll, accept }, | |
fn init(self: *Server, ring: *Ring, comptime port: u16) !void { | |
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
errdefer std.os.close(self.fd); | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(self.fd, 128); | |
self.gpa = .{}; | |
self.state = .poll; | |
self.completion = .{ .onComplete = Server.onCompletion }; | |
try self.submitAccept(ring); | |
std.debug.warn("Listening on :{}\n", .{port}); | |
} | |
fn deinit(self: *Server) void { | |
std.os.close(self.fd); | |
std.debug.warn("Server closed\n", .{}); | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Server, "completion", completion); | |
self.process(ring, completion.result) catch self.deinit(); | |
} | |
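// Two states: a .poll completion means the listening socket is readable (retry accept); an .accept completion carries the new client fd or a negated errno.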
fn process(self: *Server, ring: *Ring, result: i32) !void { | |
switch (self.state) { | |
.poll => { | |
if (result & (linux.POLLERR | linux.POLLHUP) != 0) | |
return error.Closed; | |
try self.submitAccept(ring); | |
}, | |
.accept => { | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitAccept(ring); | |
return; | |
}, | |
std.os.EAGAIN => { | |
try ring.submitPoll(self.fd, linux.POLLIN, &self.completion); | |
self.state = .poll; | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const client_fd: std.os.socket_t = result; | |
Client.init(self, ring, client_fd) catch |err| { | |
std.os.close(client_fd); | |
std.debug.warn("Failed to start client: {}\n", .{client_fd}); | |
}; | |
try self.submitAccept(ring); | |
}, | |
} | |
} | |
fn submitAccept(self: *Server, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .ACCEPT; | |
sqe.fd = self.fd; | |
sqe.rw_flags = std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC; | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .accept; | |
} | |
}; | |
const Client = struct { | |
server: *Server, | |
fd: std.os.socket_t, | |
reader: Reader, | |
writer: Writer, | |
is_closed: bool, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn init(server: *Server, ring: *Ring, fd: std.os.socket_t) !void { | |
const allocator = &server.gpa.allocator; | |
const self = try allocator.create(Client); | |
errdefer allocator.destroy(self); | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
self.fd = fd; | |
self.server = server; | |
self.is_closed = false; | |
try self.reader.init(ring); | |
try self.writer.init(ring); | |
} | |
fn deinit(self: *Client) void { | |
std.os.close(self.fd); | |
const allocator = &self.server.gpa.allocator; | |
allocator.destroy(self); | |
} | |
const Reader = struct { | |
state: enum { poll, read, stop } = .stop, | |
completion: Ring.Completion, | |
recv_bytes: usize = 0, | |
recv_buffer: [4096]u8 = undefined, | |
iovec: std.os.iovec = undefined, | |
fn init(self: *Reader, ring: *Ring) !void { | |
const client = @fieldParentPtr(Client, "reader", self); | |
self.* = .{ .completion = .{ .onComplete = onCompletion } }; | |
self.iovec.iov_base = &self.recv_buffer; | |
self.iovec.iov_len = self.recv_buffer.len; | |
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion); | |
self.state = .poll; | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Reader, "completion", completion); | |
const client = @fieldParentPtr(Client, "reader", self); | |
self.process(client, ring, completion.result) catch |err| { | |
if (self.state == .stop and client.writer.state == .stop) | |
client.deinit(); | |
}; | |
} | |
fn process(self: *Reader, client: *Client, ring: *Ring, result: i32) !void { | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
if (client.writer.state == .idle) | |
client.writer.state = .stop; | |
} | |
if (client.is_closed) | |
return error.closed; | |
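// A .poll completion carries the revents mask from POLL_ADD; a .read completion carries the READV result, with EAGAIN falling back to another poll.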
switch (self.state) { | |
.poll => { | |
if (result & (linux.POLLERR | linux.POLLHUP) != 0) | |
return error.Closed; | |
try self.submitRead(ring); | |
}, | |
.read => { | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitRead(ring); | |
return; | |
}, | |
std.os.EAGAIN => { | |
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion); | |
self.state = .poll; | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes = @intCast(usize, result); | |
self.recv_bytes += bytes; | |
if (bytes == 0) | |
return error.Eof; | |
while (true) { | |
const request_buffer = self.recv_buffer[0..self.recv_bytes]; | |
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| { | |
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len]; | |
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer); | |
self.recv_bytes = unparsed_buffer.len; | |
client.writer.send_bytes += HTTP_RESPONSE.len; | |
if (client.writer.state == .idle) { | |
client.writer.state = .write; | |
try client.writer.process(client, ring, 0); | |
} | |
continue; | |
} | |
const readable_buffer = self.recv_buffer[self.recv_bytes..]; | |
if (readable_buffer.len == 0) | |
return error.HttpRequestTooLarge; | |
self.iovec.iov_base = readable_buffer.ptr; | |
self.iovec.iov_len = readable_buffer.len; | |
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion); | |
self.state = .poll; | |
return; | |
} | |
}, | |
else => {} | |
} | |
} | |
fn submitRead(self: *Reader, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .READV; | |
sqe.fd = @fieldParentPtr(Client, "reader", self).fd; | |
sqe.addr = @ptrToInt(&self.iovec); | |
sqe.len = 1; | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .read; | |
} | |
}; | |
const Writer = struct { | |
state: enum { poll, write, idle, stop } = .idle, | |
completion: Ring.Completion, | |
send_bytes: usize = 0, | |
send_partial: usize = 0, | |
iovec: std.os.iovec_const = undefined, | |
fn init(self: *Writer, ring: *Ring) !void { | |
self.* = .{ .completion = .{ .onComplete = onCompletion } }; | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Writer, "completion", completion); | |
const client = @fieldParentPtr(Client, "writer", self); | |
self.process(client, ring, completion.result) catch |err| { | |
if (self.state == .stop and client.reader.state == .stop) | |
client.deinit(); | |
}; | |
} | |
fn process(self: *Writer, client: *Client, ring: *Ring, result: i32) !void { | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
} | |
if (client.is_closed) | |
return error.closed; | |
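// Mirrors the reader: wait for POLLOUT, then WRITEV as much of the queued response data as the socket will take.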
switch (self.state) { | |
.poll => { | |
if (result & (linux.POLLERR | linux.POLLHUP) != 0) | |
return error.Closed; | |
try self.submitWrite(ring); | |
}, | |
.write => { | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitWrite(ring); | |
return; | |
}, | |
std.os.EAGAIN => { | |
try ring.submitPoll(client.fd, linux.POLLOUT, &self.completion); | |
self.state = .poll; | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes_written = @intCast(usize, result); | |
self.send_bytes -= bytes_written; | |
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len; // offset reached inside the current response after a short write | |
if (self.send_bytes == 0) { | |
self.state = .idle; | |
return; | |
} | |
const NUM_RESPONSE_CHUNKS = 128; | |
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS; | |
self.iovec.iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial; | |
self.iovec.iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial); | |
try ring.submitPoll(client.fd, linux.POLLOUT, &self.completion); | |
self.state = .poll; | |
}, | |
else => {}, | |
} | |
} | |
fn submitWrite(self: *Writer, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .WRITEV; | |
sqe.fd = @fieldParentPtr(Client, "writer", self).fd; | |
sqe.addr = @ptrToInt(&self.iovec); | |
sqe.len = 1; | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .write; | |
} | |
}; | |
}; |
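// A typical benchmark run against the address this server binds (invocation is illustrative, not from the gist): wrk -t2 -c64 -d10s http://127.0.0.1:12345/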