Skip to content

Instantly share code, notes, and snippets.

@lithdew
Created March 1, 2021 11:56
Show Gist options
  • Save lithdew/d186960990577f2546de2c97289fb104 to your computer and use it in GitHub Desktop.
Save lithdew/d186960990577f2546de2c97289fb104 to your computer and use it in GitHub Desktop.
zig: mpsc queue
const std = @import("std");
const os = std.os;
const mem = std.mem;
const testing = std.testing;
/// Returns an intrusive, linked-list based queue type whose nodes carry a
/// value of type `T`.
///
/// Nodes are supplied and owned by the caller; the queue only links them
/// together and never allocates or frees.
///
/// Producers publish by atomically swapping `back`, so several threads can
/// push at once.
/// NOTE(review): the receive functions update `front.next` with plain atomic
/// stores rather than compare-and-swap, so they appear to assume a single
/// consumer thread — confirm before using with several consumers.
pub fn Channel(comptime T: type) type {
return struct {
const Self = @This();
/// A queue entry. `next` is managed by the queue once the node is pushed.
pub const Node = struct {
next: ?*Node = null,
value: T,
};
/// Most recently pushed node, or `&front` when the queue is empty.
/// NOTE(review): the 64-byte alignment presumably keeps this field on a
/// different cache line from `front` — confirm.
back: *Node align(64),
/// Not read or written by any function in this struct.
count: usize,
/// Stub node: `front.next` is the oldest queued node (null when empty).
/// `front.value` is never initialized by `init`.
front: Node align(64),
/// Puts the queue into the empty state: `front.next` is null and `back`
/// points at the stub. Must run before any other method, since callers in
/// this file create the struct as `undefined`.
pub fn init(self: *Self) void {
@atomicStore(?*Node, &self.front.next, null, .Monotonic);
@atomicStore(*Node, &self.back, &self.front, .Monotonic);
}
/// Appends `src` to the queue. Despite the name this cannot fail and
/// returns nothing.
pub fn tryPush(self: *Self, src: *Node) callconv(.Inline) void {
// The new node becomes the tail, so it must not have a successor.
@atomicStore(?*Node, &src.next, null, .Release);
// Claim the tail slot; `old_back` is whatever was the tail before us.
const old_back = @atomicRmw(*Node, &self.back, .Xchg, src, .AcqRel);
// Link the previous tail to the new node. Until this store lands, `src`
// is not reachable from `front`; the receive functions account for that
// window.
@atomicStore(?*Node, &old_back.next, src, .Release);
}
/// Appends a pre-linked chain of nodes in one step. Only `last.next` and
/// the previous tail's `next` are written here, so the caller must already
/// have linked `first` through to `last` via `next`.
pub fn tryPushBatch(self: *Self, first: *Node, last: *Node) callconv(.Inline) void {
@atomicStore(?*Node, &last.next, null, .Release);
const old_back = @atomicRmw(*Node, &self.back, .Xchg, last, .AcqRel);
@atomicStore(?*Node, &old_back.next, first, .Release);
}
/// Removes and returns the oldest node, or null when no node can be taken
/// right now (the queue is empty, or a producer has swapped `back` but has
/// not yet linked its node).
pub fn tryRecv(self: *Self) callconv(.Inline) ?*Node {
var first = @atomicLoad(?*Node, &self.front.next, .Acquire) orelse return null;
// Common case: `first` has a successor, so just advance the head.
if (@atomicLoad(?*Node, &first.next, .Acquire)) |next| {
@atomicStore(?*Node, &self.front.next, next, .Monotonic);
return first;
}
// `first` has no visible successor. If it is not the tail either, a
// producer is between its swap and its link; report "nothing available".
var last = @atomicLoad(*Node, &self.back, .Acquire);
if (first != last) return null;
// `first` looks like the only node. Clear the head first, then try to
// reset `back` to the stub; clearing after the reset could overwrite a
// link written by a producer that pushed onto the stub in between.
@atomicStore(?*Node, &self.front.next, null, .Monotonic);
if (@cmpxchgStrong(*Node, &self.back, last, &self.front, .AcqRel, .Acquire) == null) {
return first;
}
// The reset failed, so a producer swapped `back` away from `first` and
// will link its node to `first.next`. Yield until that link appears;
// errors from `sched_yield` are ignored.
var maybe_next = @atomicLoad(?*Node, &first.next, .Acquire);
while (maybe_next == null) : (os.sched_yield() catch {}) {
maybe_next = @atomicLoad(?*Node, &first.next, .Acquire);
}
@atomicStore(?*Node, &self.front.next, maybe_next, .Monotonic);
return first;
}
/// Removes as many nodes as are currently linked and returns how many were
/// taken. The batch is the chain starting at `b_first.*` and ending at
/// `b_last.*`, connected through `next`.
///
/// `b_first.*` is written whenever the head is non-null, even if the
/// function then returns 0; `b_last.*` is written only when at least one
/// node is returned. Callers should therefore walk exactly the returned
/// number of nodes. The `next` of `b_last.*` is null when the batch emptied
/// the queue and points at a still-queued node otherwise.
pub fn tryRecvBatch(self: *Self, b_first: **Node, b_last: **Node) callconv(.Inline) usize {
var front = @atomicLoad(?*Node, &self.front.next, .Acquire) orelse return 0;
b_first.* = front;
var maybe_next = @atomicLoad(?*Node, &front.next, .Acquire);
var result: usize = 0;
// Claim every node that already has a visible successor. After the loop
// `front` is the first node whose `next` was observed as null.
while (maybe_next) |next| {
result += 1;
b_last.* = front;
front = next;
maybe_next = @atomicLoad(?*Node, &next.next, .Acquire);
}
var last = @atomicLoad(*Node, &self.back, .Acquire);
// `front` is not the tail: a producer is mid-push behind it. Leave `front`
// in the queue as the new head and return what was claimed so far.
if (front != last) {
@atomicStore(?*Node, &self.front.next, front, .Release);
return result;
}
// `front` is the tail: try to take it as well and leave the queue empty.
// The head is cleared before `back` is reset, as in `tryRecv`.
@atomicStore(?*Node, &self.front.next, null, .Monotonic);
if (@cmpxchgStrong(*Node, &self.back, last, &self.front, .AcqRel, .Acquire) == null) {
result += 1;
b_last.* = front;
return result;
}
// A producer swapped `back` in the meantime; yield until it has linked its
// node behind `front`, then make that node the new head. Errors from
// `sched_yield` are ignored.
maybe_next = @atomicLoad(?*Node, &front.next, .Acquire);
while (maybe_next == null) : (os.sched_yield() catch {}) {
maybe_next = @atomicLoad(?*Node, &front.next, .Acquire);
}
result += 1;
@atomicStore(?*Node, &self.front.next, maybe_next, .Monotonic);
b_last.* = front;
return result;
}
};
}
/// Total number of nodes the consumer loops wait for before returning.
const NUM_ITEMS = 15_000_000;
/// Number of producer threads spawned by `main`; each producer pushes
/// `NUM_ITEMS / NUM_PRODUCERS` nodes.
const NUM_PRODUCERS = 15;
/// Channel instantiation shared by the benchmark's thread entry points.
const TestChannel = Channel(u64);
/// Arguments passed by value to every spawned thread.
const Context = struct {
/// Used by producers to create nodes and by consumers to destroy them.
allocator: *mem.Allocator,
/// The channel all threads operate on.
chan: *TestChannel,
};
/// Producer thread entry point.
///
/// Allocates and pushes this producer's share of the workload
/// (`NUM_ITEMS / NUM_PRODUCERS` nodes); each node's value is its sequence
/// number within this producer. Returns the allocator's error if a node
/// cannot be created.
fn runProducer(self: Context) !void {
    const quota = NUM_ITEMS / NUM_PRODUCERS;

    var sequence: usize = 0;
    while (sequence < quota) {
        const fresh = try self.allocator.create(TestChannel.Node);
        fresh.* = TestChannel.Node{ .value = @intCast(u64, sequence) };
        self.chan.tryPush(fresh);
        sequence += 1;
    }
}
/// Consumer thread entry point (batched variant).
///
/// Repeatedly takes whatever is queued with `tryRecvBatch` and destroys every
/// node of the batch until `NUM_ITEMS` nodes have been received. Polls in a
/// tight loop while the channel has nothing to hand over.
fn runConsumerBatch(self: Context) !void {
    var first: *TestChannel.Node = undefined;
    var last: *TestChannel.Node = undefined;
    var i: usize = 0;
    while (i < NUM_ITEMS) {
        var count = self.chan.tryRecvBatch(&first, &last);
        i += count;
        while (count > 0) : (count -= 1) {
            // Read the link before the node's memory is released.
            const next = first.next;
            self.allocator.destroy(first);
            // When a batch empties the channel, its final node has a null
            // `next`. Unwrapping it with `.?` would panic in safe builds and
            // be undefined behaviour in ReleaseFast, so stop walking instead.
            first = next orelse break;
        }
    }
}
/// Consumer thread entry point (one node at a time).
///
/// Polls `tryRecv` in a tight loop and destroys each node it obtains until
/// `NUM_ITEMS` nodes have been received.
fn runConsumer(self: Context) !void {
    var received: usize = 0;
    while (received < NUM_ITEMS) {
        // Nothing available yet: poll again without counting.
        const taken = self.chan.tryRecv() orelse continue;
        self.allocator.destroy(taken);
        received += 1;
    }
}
/// Benchmark driver: one batched consumer and `NUM_PRODUCERS` producers share
/// a single channel; prints the elapsed time in milliseconds once every
/// thread has finished.
pub fn main() !void {
    var chan: TestChannel = undefined;
    chan.init();

    // Every thread receives the same allocator and channel.
    const shared = Context{
        .allocator = std.heap.c_allocator,
        .chan = &chan,
    };

    var timer = try std.time.Timer.start();

    const consumer = try std.Thread.spawn(shared, runConsumerBatch);

    var producers: [NUM_PRODUCERS]*std.Thread = undefined;
    var slot: usize = 0;
    while (slot < producers.len) : (slot += 1) {
        producers[slot] = try std.Thread.spawn(shared, runProducer);
    }

    // Join the producers first, then the consumer.
    for (producers) |producer| producer.wait();
    consumer.wait();

    const elapsed_ms = timer.read() / std.time.ns_per_ms;
    std.debug.print("Done! Took {d}.\n", .{elapsed_ms});
}
// Single-threaded smoke test: three pushes followed by one batched receive
// must hand back all three nodes, in push order, and leave the queue empty.
test {
    testing.refAllDecls(Channel(u64));
    std.debug.print("\n", .{});

    var chan: Channel(u64) = undefined;
    chan.init();

    var A = Channel(u64).Node{ .value = 0 };
    var B = Channel(u64).Node{ .value = 1 };
    var C = Channel(u64).Node{ .value = 2 };
    chan.tryPush(&A);
    chan.tryPush(&B);
    chan.tryPush(&C);

    var first: *Channel(u64).Node = undefined;
    var last: *Channel(u64).Node = undefined;
    var count = chan.tryRecvBatch(&first, &last);
    testing.expect(count == 3);

    // The batch must span exactly the pushed nodes, oldest to newest.
    testing.expect(first == &A);
    testing.expect(last == &C);

    // Walk the batch and check the values come out in FIFO order instead of
    // only printing them.
    var expected: u64 = 0;
    while (count > 0) : (count -= 1) {
        std.debug.print("GOT: {}\n", .{first.value});
        testing.expect(first.value == expected);
        expected += 1;
        // The last node of a fully drained queue has no successor.
        first = first.next orelse break;
    }
    testing.expect(expected == 3);

    // Draining everything must leave the queue empty.
    testing.expect(chan.tryRecv() == null);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment