Skip to content

Instantly share code, notes, and snippets.

@rlapz
Created April 22, 2023 14:49
Show Gist options
  • Save rlapz/0d46d8ad2005ea7dcfca1bc7874bdb63 to your computer and use it in GitHub Desktop.
Save rlapz/0d46d8ad2005ea7dcfca1bc7874bdb63 to your computer and use it in GitHub Desktop.
fturing: zig implementation
const std = @import("std");
const assert = std.debug.assert;
const fmt = std.fmt;
const heap = std.heap;
const mem = std.mem;
const log = std.log;
const os = std.os;
const io = std.io;
const net = std.net;
const linux = os.linux;
const O = os.O;
const dprint = std.debug.print;
// Directory prefix (relative path) under which received files are created.
const upload_dir = "upload_dir";
// Size in bytes of each client's raw receive buffer.
const buffer_size = 8192;
// Entry count passed to the io_uring instance; also the size of the CQE batch array.
const server_queue_depth = 32;
// Item count passed to `SClientPool.initPreheated` when the server starts.
const server_clients_max = 1024;
/// Fixed-size header a client sends before the file contents.
///
/// Layout (264 bytes, no padding, checked at compile time):
///   offset 0: `fsize`     - file size in bytes, stored big-endian
///   offset 8: `fname_len` - number of valid bytes in `fname`
///   offset 9: `fname`     - file name bytes
const Packet = extern struct {
    fsize: u64 align(1),
    fname_len: u8 align(1),
    fname: [255]u8 align(1),

    const size = @sizeOf(Packet);

    // compile time assertions
    comptime {
        // check packet size
        assert(size == 264);

        // check field offsets
        assert(@offsetOf(Packet, "fsize") == 0);
        assert(@offsetOf(Packet, "fname_len") == 8);
        assert(@offsetOf(Packet, "fname") == 9);
    }

    /// Copies `fname` into the packet, appends a NUL byte and records the
    /// length (the NUL is not counted in `fname_len`).
    ///
    /// Returns `error.FileNameTooLong` when `fname` does not leave room for
    /// the trailing NUL (i.e. `fname.len >= 255`).
    fn setFname(self: *Packet, fname: []const u8) !void {
        if (fname.len >= self.fname.len)
            return error.FileNameTooLong;

        // `mem.copy` needs a slice destination, hence the address-of.
        mem.copy(u8, &self.fname, fname);
        self.fname[fname.len] = '\x00';
        // Safe: the guard above ensures fname.len < 255.
        self.fname_len = @intCast(u8, fname.len);
    }

    /// Returns the first `fname_len` bytes of `fname`.
    fn getFname(self: *const Packet) []const u8 {
        return self.fname[0..self.fname_len];
    }

    /// Stores `fsize` in big-endian byte order.
    fn setFsize(self: *Packet, fsize: u64) void {
        self.fsize = mem.nativeToBig(u64, fsize);
    }

    /// Returns the stored file size converted back to native byte order.
    fn getFsize(self: *const Packet) u64 {
        return mem.bigToNative(u64, self.fsize);
    }

    /// Validates the file name carried by the packet.
    ///
    /// Returns `error.InvalidFileProp` when the name is empty, contains
    /// "..", or contains a NUL byte. Only the `fname_len` bytes that make up
    /// the name are inspected; bytes after it are not part of the name and
    /// must not influence the verdict.
    fn check(self: *const Packet) !void {
        const name = self.getFname();

        if (name.len == 0)
            return error.InvalidFileProp;

        // An embedded NUL can never be part of a valid path.
        if (mem.indexOfScalar(u8, name, 0) != null)
            return error.InvalidFileProp;

        if (mem.indexOf(u8, name, "..") != null)
            return error.InvalidFileProp;
    }
};
/// Thin wrapper around `linux.IO_Uring` that queues accept/recv/write
/// operations and retries when the submission queue has no free entry.
const Uring = struct {
    uring: linux.IO_Uring,

    /// Creates the ring with `depth` entries and no setup flags.
    fn init(depth: u13) !Uring {
        return .{
            .uring = try linux.IO_Uring.init(depth, 0),
        };
    }

    fn deinit(self: *Uring) void {
        self.uring.deinit();
    }

    /// Submits everything queued so far and waits for at least one completion.
    fn submitAndWait(self: *Uring) !void {
        _ = try self.uring.submit_and_wait(1);
    }

    /// Copies completed entries into `cqes` (waiting for at least one) and
    /// returns how many were copied.
    fn getCqes(self: *Uring, cqes: []linux.io_uring_cqe) !u32 {
        return self.uring.copy_cqes(cqes, 1);
    }

    /// Queues an accept on `sock_fd`; `udata` is attached to the operation.
    fn accept(self: *Uring, sock_fd: os.fd_t, udata: u64) void {
        while (true) {
            _ = self.uring.accept(udata, sock_fd, null, null, 0) catch {
                // Queueing failed: push pending entries to the kernel, retry.
                self.flush();
                continue;
            };
            return;
        }
    }

    /// Queues a receive from `fd` into `buffer`; `udata` is attached to the
    /// operation. `buffer` must stay valid until the completion arrives.
    fn recv(self: *Uring, fd: os.fd_t, buffer: []u8, udata: u64) void {
        while (true) {
            _ = self.uring.recv(udata, fd, .{ .buffer = buffer }, 0) catch {
                // Queueing failed: push pending entries to the kernel, retry.
                self.flush();
                continue;
            };
            return;
        }
    }

    /// Queues a write of `buffer` to `fd` at file offset `offt`; `udata` is
    /// attached to the operation. `buffer` must stay valid until the
    /// completion arrives.
    fn write(self: *Uring, fd: os.fd_t, buffer: []const u8, offt: u64, udata: u64) void {
        while (true) {
            _ = self.uring.write(udata, fd, buffer, offt) catch {
                // Queueing failed: push pending entries to the kernel, retry.
                self.flush();
                continue;
            };
            return;
        }
    }

    /// Submits the pending queue entries so new ones can be queued.
    ///
    /// The queueing functions return `void`, so a failed submit cannot be
    /// reported to the caller. It is turned into an explicit panic carrying
    /// the error name instead of `unreachable`, which would be undefined
    /// behaviour in unchecked release builds.
    fn flush(self: *Uring) void {
        _ = self.uring.submit() catch |err|
            std.debug.panic("io_uring submit failed: {s}", .{@errorName(err)});
    }
};
/// State of one client connection while its file is being received.
///
/// Exactly one io_uring operation is in flight per client at any time; the
/// client's address is used as the operation's user data.
const SClient = struct {
    state: State,
    sock_fd: os.fd_t,
    file_fd: ?os.fd_t,
    /// File size announced by the client in the header packet.
    file_size: u64,
    /// While in `.prop`: header bytes received so far.
    /// Afterwards: file bytes written so far (also the next file offset).
    bytes: u64,
    /// Start of the not-yet-written part of the current chunk in `packet.raw`.
    chunk_off: usize,
    /// Length of the not-yet-written part of the current chunk.
    chunk_len: usize,
    uring: *Uring,
    /// Receive buffer; its first `Packet.size` bytes double as the header.
    packet: extern union {
        pkt: Packet,
        raw: [buffer_size]u8,
    },

    const State = enum {
        /// Receiving the header packet.
        prop,
        /// A recv of file data is in flight.
        recv,
        /// A write of file data is in flight.
        write,
        /// The connection is done and must be closed.
        finish,
    };

    /// Resets the client for a fresh connection on `sock_fd` and queues the
    /// recv for the header packet.
    fn set(self: *SClient, sock_fd: os.fd_t, uring: *Uring) void {
        self.state = .prop;
        self.sock_fd = sock_fd;
        self.file_fd = null;
        self.file_size = 0;
        self.bytes = 0;
        self.chunk_off = 0;
        self.chunk_len = 0;
        self.uring = uring;

        self.uring.recv(sock_fd, self.packet.raw[0..Packet.size], @ptrToInt(self));
    }

    /// Closes the socket and, when one was opened, the destination file.
    fn unset(self: *SClient) void {
        os.close(self.sock_fd);
        if (self.file_fd) |fd|
            os.close(fd);
    }

    /// Advances the state machine with the result `res` of the operation
    /// that just completed.
    ///
    /// Returns `true` while the connection stays open. Returns `false` after
    /// the client's descriptors have been closed; the caller is then expected
    /// to release the client object.
    fn handle(self: *SClient, res: i32) bool {
        var ret = false;
        defer if (!ret) {
            self.unset();
            log.info("closed connection: {}", .{self.sock_fd});
        };

        // 0 means end of stream. A negative value is an error code and must
        // never reach the unsigned casts in the state handlers.
        if (res <= 0)
            return ret;

        const state = switch (self.state) {
            .write => self.handleFileWrite(res),
            .recv => self.handleFileRecv(res),
            .prop => self.handleFileProp(res),
            .finish => State.finish,
        };

        if (state == .finish)
            return ret;

        self.state = state;
        ret = true;
        return ret;
    }

    /// Collects the header packet. Once it is complete, opens the
    /// destination file and queues the first recv of file data.
    fn handleFileProp(self: *SClient, res: i32) State {
        const recvd = self.bytes + @intCast(u64, res);
        if (recvd < Packet.size) {
            // Header still incomplete: ask for the missing tail.
            self.uring.recv(
                self.sock_fd,
                self.packet.raw[recvd..Packet.size],
                @ptrToInt(self),
            );

            self.bytes = recvd;
            return .prop;
        }

        if (recvd != Packet.size) {
            log.debug("corrupted", .{});
            return .finish;
        }

        self.filePrep() catch |err| {
            log.err("filePrep: {s}", .{@errorName(err)});
            return .finish;
        };

        self.uring.recv(self.sock_fd, &self.packet.raw, @ptrToInt(self));

        // From here on `bytes` counts file bytes written.
        self.bytes = 0;
        return .recv;
    }

    /// Accounts for `res` bytes written to the file. A short write leaves
    /// part of the chunk in the buffer, so the remainder is written before
    /// the buffer is reused for the next recv.
    fn handleFileWrite(self: *SClient, res: i32) State {
        const written = @intCast(usize, res);
        self.bytes += written;

        if (written < self.chunk_len) {
            self.chunk_off += written;
            self.chunk_len -= written;
            self.submitChunk();
            return .write;
        }

        self.chunk_len = 0;
        self.uring.recv(self.sock_fd, &self.packet.raw, @ptrToInt(self));
        return .recv;
    }

    /// Queues a write for the `res` bytes just received, clamped so that
    /// nothing is written past the announced file size. Finishes when the
    /// whole file has already been written.
    fn handleFileRecv(self: *SClient, res: i32) State {
        if (self.bytes >= self.file_size)
            return .finish;

        const remaining = self.file_size - self.bytes;
        const recvd = @intCast(usize, res);

        self.chunk_off = 0;
        // The cast is safe: in that branch `remaining` is below `recvd`.
        self.chunk_len = if (recvd > remaining) @intCast(usize, remaining) else recvd;
        self.submitChunk();
        return .write;
    }

    /// Queues a write of the pending chunk at file offset `bytes`.
    fn submitChunk(self: *SClient) void {
        self.uring.write(
            self.file_fd.?,
            self.packet.raw[self.chunk_off .. self.chunk_off + self.chunk_len],
            self.bytes,
            @ptrToInt(self),
        );
    }

    /// Validates the header, then creates (or truncates) the destination
    /// file under `upload_dir` and records the announced file size.
    fn filePrep(self: *SClient) !void {
        const pkt = &self.packet.pkt;
        try pkt.check();

        var buffer: [4096]u8 = undefined;
        const fname = try fmt.bufPrint(&buffer, "{s}/{s}", .{
            upload_dir,
            pkt.getFname(),
        });

        self.file_fd = try os.open(fname, O.TRUNC | O.CREAT | O.WRONLY, 0o644);
        self.file_size = pkt.getFsize();

        log.info("File name: {s}: {}", .{ pkt.getFname(), pkt.getFsize() });
    }
};
// Memory pool of `SClient` objects with default options; `Server` creates one
// item per accepted connection and destroys it when the connection ends.
const SClientPool = heap.MemoryPoolExtra(SClient, .{});
/// Upload server: accepts connections and drives every client's state
/// machine from a single io_uring completion loop.
const Server = struct {
    allocator: mem.Allocator,
    /// `true` while `run` should keep looping.
    is_alive: bool,
    /// Listening socket; only set while `run` is executing.
    sock_fd: ?os.fd_t,
    clients: SClientPool,
    uring: Uring,

    /// User data that marks accept completions. Client operations use the
    /// client's address instead.
    const accept_udata: u64 = 1;

    /// Preallocates the client pool and creates the io_uring instance.
    fn init(allocator: mem.Allocator) !Server {
        var clients = try SClientPool.initPreheated(allocator, server_clients_max);
        errdefer clients.deinit();

        return .{
            .allocator = allocator,
            .is_alive = false,
            .sock_fd = null,
            .clients = clients,
            .uring = try Uring.init(server_queue_depth),
        };
    }

    fn deinit(self: *Server) void {
        self.uring.deinit();
        self.clients.deinit();
    }

    /// Listens on `host`:`port` and processes completions until `is_alive`
    /// is cleared or submitting/reaping completions fails.
    fn run(self: *Server, host: []const u8, port: u16) !void {
        var sserver = std.net.StreamServer.init(.{ .reuse_address = true });
        defer sserver.deinit();

        try sserver.listen(try net.Address.parseIp(host, port));
        self.sock_fd = sserver.sockfd.?;

        // `sserver.deinit()` runs when this function exits; do not keep a
        // reference to its socket (or a stale alive flag) beyond that point.
        defer {
            self.is_alive = false;
            self.sock_fd = null;
        }

        self.uring.accept(self.sock_fd.?, accept_udata);
        self.is_alive = true;

        log.info("listening on: \"{s} {}\"", .{ host, port });

        var cqes: [server_queue_depth]linux.io_uring_cqe = undefined;
        while (self.is_alive) {
            try self.uring.submitAndWait();

            const sz = try self.uring.getCqes(&cqes);
            if (sz == 0)
                continue;

            for (cqes[0..sz]) |*cqe| {
                const is_accept = cqe.user_data == accept_udata;

                const err = cqe.err();
                if (err != .SUCCESS) {
                    log.err("error: cqe.res: {s}", .{@tagName(err)});

                    if (is_accept) {
                        // Every accept is one-shot: without re-arming, no
                        // further connection would ever be accepted.
                        if (self.is_alive)
                            self.uring.accept(self.sock_fd.?, accept_udata);
                    } else {
                        // The client's only in-flight operation failed and
                        // nothing else will complete for it: release it.
                        const failed = @intToPtr(*SClient, cqe.user_data);
                        failed.unset();
                        self.clients.destroy(failed);
                    }
                    continue;
                }

                if (is_accept) {
                    self.handleAccept(cqe.res);
                } else {
                    const scl = @intToPtr(*SClient, cqe.user_data);
                    if (!scl.handle(cqe.res))
                        self.clients.destroy(scl);
                }
            }
        }
    }

    /// Shuts the listening socket down and asks the loop in `run` to exit.
    /// Does nothing when the server is not running.
    fn stop(self: *Server) void {
        if (!self.is_alive)
            return;

        if (self.sock_fd) |fd| {
            // Best effort: the loop exits through `is_alive` either way.
            os.shutdown(fd, .both) catch |err|
                log.warn("shutdown: {s}", .{@errorName(err)});
        }

        self.is_alive = false;
    }

    /// Wraps the accepted socket `res` in a pooled client and re-arms accept.
    /// When no client object can be created the socket is closed right away.
    fn handleAccept(self: *Server, res: i32) void {
        defer self.uring.accept(self.sock_fd.?, accept_udata);

        const sclient = self.clients.create() catch |err| {
            log.err("client pool: {s}", .{@errorName(err)});
            os.close(res);
            return;
        };

        sclient.set(res, &self.uring);
        log.info("new connecton: fd: {}", .{sclient.sock_fd});
    }
};
// Empty placeholder: it declares no fields or methods, and nothing in the
// visible source refers to it.
const Client = struct {};
/// Program entry point: builds a server backed by the page allocator and
/// serves on "::" port 8005 until `run` returns. Any error from setup or
/// from the serve loop is returned to the runtime.
pub fn main() !void {
    const listen_host = "::";
    const listen_port: u16 = 8005;

    var server = try Server.init(heap.page_allocator);
    defer server.deinit();

    try server.run(listen_host, listen_port);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment