//! adapted from https://ravendb.net/articles/atomic-reference-counting-with-zig-code-samples
//! compiled with zig version 0.11.0-dev.3771+128fd7dd0
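//!
//! A set-once, atomically reference-counted pointer: all state lives in one
//! 64-bit word updated with cmpxchg, and acquire() blocks on a futex until a
//! value or an error is published via init() / errored().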
const std = @import("std");
pub fn RefCounted(comptime T: type) type {
    return struct {
        const Self = @This();

        /// All state packs into one 64-bit word so it can be updated with a
        /// single cmpxchg. `value` holds the pointer shifted right by 4;
        /// when `is_error` is set, `references` holds an error code instead
        /// of a reference count.
        const InternalRef = packed union {
            item: packed struct { value: u44, is_error: bool, references: u19 },
            raw: u64,
        };

        data: InternalRef,
        version: std.atomic.Atomic(u32),

        comptime {
            if (@sizeOf(InternalRef) != @sizeOf(u64)) {
                @compileError("Wrong size for InternalRef?");
            }
        }
        /// Attempts to take a reference. Returns null if no value has been
        /// set yet, or the stored error if `errored` was called.
        pub fn tryAcquire(self: *Self) !?*T {
            while (true) {
                var cur = self.data;
                var original = cur;
                if (cur.item.is_error) {
                    // in the error state, `references` holds the error code
                    return @errorFromInt(@intCast(u16, cur.item.references));
                }
                if (cur.item.value == 0) {
                    return null;
                }
                cur.item.references += 1;
                if (cur.item.references == std.math.maxInt(u19)) {
                    return error.ReferenceCounterWouldOverflow;
                }
                // publish the incremented count; on contention, retry the loop
                if (@cmpxchgWeak(u64, &self.data.raw, original.raw, cur.raw, .Monotonic, .Monotonic) == null) {
                    // undo the >>4 from pointerToU44 to recover the pointer
                    var v = @intCast(u64, cur.item.value) << 4;
                    return @ptrFromInt(*T, v);
                }
            }
        }
        /// Blocks (via futex) until a value or error is set, then takes a
        /// reference.
        pub fn acquire(self: *Self) !*T {
            while (true) {
                const version = self.version.load(.Monotonic);
                if (try self.tryAcquire()) |v| {
                    return v;
                }
                // sleeps unless `version` has already changed since we loaded
                // it, which would mean a value was published concurrently
                std.Thread.Futex.wait(&self.version, version);
            }
        }
        /// Publishes `val` with an initial reference count of 1.
        pub fn init(self: *Self, val: *T) !void {
            return self.set(val, null);
        }

        /// Publishes an error; subsequent acquires fail with `code`.
        pub fn errored(self: *Self, code: anyerror) !void {
            return self.set(null, code);
        }
        fn set(self: *Self, val: ?*T, code: ?anyerror) !void {
            while (true) {
                const cur = self.data;
                if (cur.item.value != 0 or cur.item.is_error) {
                    return error.ValueAlreadySet;
                }
                var update: InternalRef = undefined;
                if (code) |e| {
                    // store the error code where the reference count would go
                    update = InternalRef{ .item = .{ .value = 0, .references = @intFromError(e), .is_error = true } };
                } else if (val) |v| {
                    var u44val = try pointerToU44(v);
                    update = InternalRef{ .item = .{ .value = u44val, .references = 1, .is_error = false } };
                } else {
                    return error.BothValueAndCodeAreEmpty;
                }
                if (@cmpxchgWeak(u64, &self.data.raw, cur.raw, update.raw, .Monotonic, .Monotonic) == null) {
                    // bump the version and wake all waiters blocked in acquire
                    _ = self.version.fetchAdd(1, .Release);
                    std.Thread.Futex.wake(&self.version, std.math.maxInt(u32));
                    break;
                }
            }
        }
        /// Compresses a pointer into 44 bits: the low 4 bits must be zero
        /// (16-byte alignment) and the high 16 bits must be clear, leaving
        /// 44 significant bits after shifting right by 4.
        fn pointerToU44(val: *T) !u44 {
            const iVal = @intFromPtr(val);
            if (iVal & 0b1111 != 0) {
                return error.PointerValueNot16BytesAligned;
            }
            if (iVal >> 48 != 0) {
                return error.PointerValueHigh16BitsArentCleared;
            }
            return @intCast(u44, iVal >> 4);
        }
        /// Drops a reference; frees the value with `destroy` when the last
        /// reference is released.
        pub fn release(self: *Self, v: *T, comptime destroy: fn (*T, std.mem.Allocator) void, allocator: std.mem.Allocator) !void {
            const val = try pointerToU44(v);
            while (true) {
                var cur = self.data;
                const original = cur;
                if (cur.item.is_error) {
                    // in the error state, `references` holds the error code
                    return @errorFromInt(@intCast(u16, cur.item.references));
                }
                if (cur.item.value != val) {
                    return error.ReturnedValueDoesNotMatchStoredValue;
                }
                if (cur.item.references == 1) {
                    // last reference: clear the whole word
                    cur.raw = 0;
                } else {
                    cur.item.references -= 1;
                }
                if (@cmpxchgWeak(u64, &self.data.raw, original.raw, cur.raw, .Monotonic, .Monotonic) == null) {
                    if (cur.raw == 0) {
                        destroy(v, allocator);
                    }
                    return;
                }
            }
        }
    };
}
const Foo = struct {
    l: ?*Foo = null,
    r: ?*Foo = null,
    val: usize = 0,

    pub fn deinit(foo: *Foo, allocator: std.mem.Allocator) void {
        allocator.destroy(foo);
    }
};
const RcFoo = RefCounted(Foo);
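
// Hypothetical convenience constructor (not part of the original gist): an
// unset counter, equivalent to the struct literal used in the tests below.
fn emptyRcFoo() RcFoo {
    return .{ .data = .{ .raw = 0 }, .version = std.atomic.Atomic(u32).init(0) };
}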
test {
    const alloc = std.testing.allocator;
    const a = try alloc.create(Foo);
    // backstop: two references are still held below, so release() never
    // frees `a` here and this defer must
    defer a.deinit(alloc);
    var rca = RcFoo{ .data = .{ .raw = 0 }, .version = std.atomic.Atomic(u32).init(0) };
    try rca.init(a); // references == 1
    try std.testing.expectEqual(a, try rca.acquire()); // references == 2
    try std.testing.expectEqual(a, try rca.acquire()); // references == 3
    try rca.release(a, Foo.deinit, alloc); // references == 2
}
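
// The tests below are additions, not part of the original gist: a sketch of
// the error path and of the last-release destruction path, assuming error
// codes round-trip through @intFromError/@errorFromInt within one binary.
test "errored propagates through acquire" {
    var rc = emptyRcFoo();
    try rc.errored(error.OutOfMemory);
    try std.testing.expectError(error.OutOfMemory, rc.acquire());
}

test "last release destroys the value" {
    const alloc = std.testing.allocator;
    const b = try alloc.create(Foo);
    var rc = emptyRcFoo();
    try rc.init(b); // references == 1
    try std.testing.expectEqual(b, try rc.acquire()); // references == 2
    try rc.release(b, Foo.deinit, alloc); // references == 1
    try rc.release(b, Foo.deinit, alloc); // references == 0, Foo.deinit runs
}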