Created
June 22, 2023 21:18
-
-
Save travisstaloch/24bf01b3d3f7e693cdcd104628f23c55 to your computer and use it in GitHub Desktop.
Simple reference counting in Zig, adapted from https://ravendb.net/articles/atomic-reference-counting-with-zig-code-samples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Simple atomic reference counting in Zig.
//! Adapted from https://ravendb.net/articles/atomic-reference-counting-with-zig-code-samples
//! Compiled with Zig version 0.11.0-dev.3771+128fd7dd0
const std = @import("std"); | |
/// Returns a lock-free, reference-counted slot for a heap-allocated `*T`.
///
/// The whole state lives in one `u64` so it can be updated with a single
/// compare-and-swap:
///   * `value`      - the pointer shifted right by 4, so `*T` must be 16-byte
///                    aligned and its address must fit in 48 bits;
///   * `is_error`   - set when the slot was resolved with `errored`;
///   * `references` - the reference count, or the error's integer code when
///                    `is_error` is set.
/// A slot whose `value` is 0 and `is_error` is false is empty. `release`
/// resets the slot to `raw == 0` when the last reference is dropped.
/// `version` is bumped, and futex waiters are woken, every time the slot is set.
/// This is what lets `acquire` block until a value or error is published.
pub fn RefCounted(comptime T: type) type {
    return struct {
        const Self = @This();
        const InternalRef = packed union { item: packed struct { value: u44, is_error: bool, references: u19 }, raw: u64 };

        data: InternalRef = .{ .raw = 0 },
        version: std.atomic.Atomic(u32) = std.atomic.Atomic(u32).init(0),

        comptime {
            if (@sizeOf(InternalRef) != @sizeOf(u64)) {
                @compileError("Wrong size for InternalRef?");
            }
        }

        /// Atomically snapshots the packed state. A plain read of `self.data`
        /// would be a data race with the compare-and-swap writers.
        fn loadState(self: *Self) InternalRef {
            return .{ .raw = @atomicLoad(u64, &self.data.raw, .Monotonic) };
        }

        /// Takes one more reference to the stored pointer without blocking.
        ///
        /// Returns:
        ///   * the pointer on success;
        ///   * `null` if the slot is empty;
        ///   * the stored error if the slot was resolved with `errored`;
        ///   * `error.ReferenceCounterWouldOverflow` when the count is saturated.
        pub fn tryAcquire(self: *Self) !?*T {
            while (true) {
                const original = self.loadState();
                if (original.item.is_error) {
                    // Error slots keep the error code in `references` (see `set`).
                    return @errorFromInt(@intCast(u16, original.item.references));
                }
                if (original.item.value == 0) {
                    return null;
                }
                var cur = original;
                // The stored count never reaches maxInt(u19) (checked below),
                // so this increment cannot overflow.
                cur.item.references += 1;
                if (cur.item.references == std.math.maxInt(u19)) {
                    return error.ReferenceCounterWouldOverflow;
                }
                // Acquire pairs with the Release in `set`, so the caller sees
                // the pointee exactly as it was when it was published.
                if (@cmpxchgWeak(u64, &self.data.raw, original.raw, cur.raw, .Acquire, .Monotonic) == null) {
                    const addr = @intCast(usize, cur.item.value) << 4;
                    return @ptrFromInt(*T, addr);
                }
            }
        }

        /// Like `tryAcquire`, but blocks on `version` while the slot is empty.
        pub fn acquire(self: *Self) !*T {
            while (true) {
                // Acquire: the slot state read by `tryAcquire` must be at least
                // as new as this version snapshot. Otherwise a concurrent `set`
                // could be missed and we would sleep on an already-bumped version.
                const seen = self.version.load(.Acquire);
                if (try self.tryAcquire()) |ptr| {
                    return ptr;
                }
                std.Thread.Futex.wait(&self.version, seen);
            }
        }

        /// Publishes `val` with a reference count of 1.
        /// Fails with `error.ValueAlreadySet` if the slot is not empty.
        pub fn init(self: *Self, val: *T) !void {
            return self.set(val, null);
        }

        /// Resolves the slot with `code`. Subsequent `tryAcquire`/`acquire`/
        /// `release` calls return that error.
        pub fn errored(self: *Self, code: anyerror) !void {
            return self.set(null, code);
        }

        /// Stores either an error (`code` takes precedence) or a pointer into
        /// an empty slot, then wakes every thread blocked in `acquire`.
        fn set(self: *Self, val: ?*T, code: ?anyerror) !void {
            while (true) {
                const cur = self.loadState();
                if (cur.item.value != 0 or cur.item.is_error) {
                    return error.ValueAlreadySet;
                }
                // Error slots reuse `references` to hold the error's integer code.
                const update: InternalRef = if (code) |e|
                    .{ .item = .{ .value = 0, .references = @intFromError(e), .is_error = true } }
                else if (val) |v|
                    .{ .item = .{ .value = try pointerToU44(v), .references = 1, .is_error = false } }
                else
                    return error.BothValueAndCodeAreEmpty;
                // Release publishes the pointee's initialization to acquirers.
                if (@cmpxchgWeak(u64, &self.data.raw, cur.raw, update.raw, .Release, .Monotonic) == null) {
                    _ = self.version.fetchAdd(1, .Release);
                    std.Thread.Futex.wake(&self.version, std.math.maxInt(u32));
                    return;
                }
            }
        }

        /// Packs a pointer into 44 bits by dropping its 4 low (alignment) bits.
        /// Rejects pointers that are not 16-byte aligned or that use the top 16 address bits.
        fn pointerToU44(val: *T) !u44 {
            const addr = @intFromPtr(val);
            if (addr & 0b1111 != 0) {
                return error.PointerValueNot16BytesAligned;
            }
            // Widen before shifting so this also compiles where usize is 32 bits.
            if (@as(u64, addr) >> 48 != 0) {
                return error.PointerValueHigh16BitsArentCleared;
            }
            return @intCast(u44, addr >> 4);
        }

        /// Drops one reference to `v`.
        ///
        /// When the last reference goes away, the slot is emptied and
        /// `destroy(v, allocator)` is called.
        /// Fails if:
        ///   * the slot holds an error (that error is returned);
        ///   * `v` is not the stored pointer;
        ///   * `v` cannot be packed (see `pointerToU44`).
        pub fn release(self: *Self, v: *T, comptime destroy: fn (*T, std.mem.Allocator) void, allocator: std.mem.Allocator) !void {
            const val = try pointerToU44(v);
            while (true) {
                const original = self.loadState();
                if (original.item.is_error) {
                    // `set` stores the error code in `references`. `value` is 0
                    // for error slots and is not a valid error code.
                    return @errorFromInt(@intCast(u16, original.item.references));
                }
                if (original.item.value != val) {
                    return error.RetunedValueDoesNotMatchStoredValue;
                }
                var cur = original;
                if (cur.item.references == 1) {
                    // Last reference: empty the slot entirely.
                    cur.raw = 0;
                } else {
                    cur.item.references -= 1;
                }
                // AcqRel, for two reasons:
                //   * Release makes this holder's writes to `v` visible to
                //     whichever thread destroys it.
                //   * Acquire makes the other holders' writes visible here
                //     before we call `destroy`.
                if (@cmpxchgWeak(u64, &self.data.raw, original.raw, cur.raw, .AcqRel, .Monotonic) == null) {
                    if (cur.raw == 0) {
                        destroy(v, allocator);
                    }
                    return;
                }
            }
        }
    };
}
/// Sample payload type used to exercise `RefCounted`.
///
/// It holds two optional pointers to other `Foo` values (`l`, `r`) and a
/// `usize` value. Every field defaults to null or zero.
const Foo = struct {
    const Self = @This();

    l: ?*Self = null,
    r: ?*Self = null,
    val: usize = 0,

    /// Returns this node's memory to `allocator`, which must be the allocator
    /// that created it. Child pointers are not followed: only this node is freed.
    pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
        allocator.destroy(self);
    }
};
/// `RefCounted` specialised for the sample `Foo` payload.
const RcFoo: type = RefCounted(Foo);
// Exercises the full lifecycle of a `RefCounted` slot:
//   1. `init` takes ownership of `a` with one reference.
//   2. Two `acquire` calls add two more references.
//   3. Three `release` calls drop them; the last one destroys `a` through `Foo.deinit`.
//   4. The slot is then empty again.
test "RefCounted: acquire/release, last release destroys" {
    const alloc = std.testing.allocator;
    const a = try alloc.create(Foo);
    a.* = .{};
    var rca = RcFoo{ .data = .{ .raw = 0 }, .version = std.atomic.Atomic(u32).init(0) };
    // Until `init` succeeds the slot does not own `a`, so free it by hand on failure.
    rca.init(a) catch |err| {
        a.deinit(alloc);
        return err;
    };
    try std.testing.expectEqual(a, try rca.acquire());
    try std.testing.expectEqual(a, try rca.acquire());
    // Two acquired references plus the one taken by `init`.
    try rca.release(a, Foo.deinit, alloc);
    try rca.release(a, Foo.deinit, alloc);
    try rca.release(a, Foo.deinit, alloc);
    try std.testing.expectEqual(@as(?*Foo, null), try rca.tryAcquire());
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment