Last active
May 23, 2021 22:24
-
-
Save lithdew/2802fa5cb398ccca7d77a899a4b4441f to your computer and use it in GitHub Desktop.
zig: cancellation token
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const std = @import("std"); | |
const os = std.os; | |
const builtin = std.builtin; | |
const testing = std.testing; | |
const assert = std.debug.assert; | |
/// A simple cancellation framework for blocking/non-blocking tasks.
pub const Cancellation = struct {
    /// A callback that handles the cancellation of some blocking/non-blocking task.
    ///
    /// `next` and `prev_next` are intrusive list links managed by the `Source` the
    /// callback is registered with; initialize them to their defaults.
    pub const Callback = struct {
        /// The next callback registered with the same source.
        next: ?*Cancellation.Callback = null,
        /// The link that points at this callback (the source's `head` or the previous
        /// callback's `next`), or null while the callback is not linked into a source.
        prev_next: ?*?*Cancellation.Callback = null,
        /// Invoked when a cancellation is requested while this callback is registered,
        /// or immediately by `tryAddCallback` if the source was already cancelled.
        onCancel: fn (*Cancellation.Callback) void,
    };

    /// A cancellation source which is used to derive cancellation tokens from. Keeps
    /// track of all state that is necessary to cancel blocking/non-blocking tasks. A
    /// cancellation source may only be requested to be cancelled once; later requests
    /// are ignored.
    ///
    /// All methods are thread-safe. Callbacks are always invoked with the source's lock
    /// released, so they may call back into the source (e.g. `removeCallback`).
    pub const Source = struct {
        /// Most recently registered callback; the rest are reachable via `Callback.next`.
        head: ?*Cancellation.Callback = null,
        /// Guards `head` and the `next`/`prev_next` links of every registered callback.
        lock: std.Thread.Mutex = .{},
        /// Set atomically by the first call to `cancel`.
        cancelled: bool = false,

        /// Derive a new cancellation token.
        pub fn token(self: *Cancellation.Source) Cancellation.Token {
            return .{ .source = self };
        }

        /// It returns true if a cancellation has been requested.
        pub fn isCancelled(self: *const Cancellation.Source) bool {
            return @atomicLoad(bool, &self.cancelled, .Acquire);
        }

        /// Attempts to register a callback that will be invoked when a cancellation is requested. If
        /// a cancellation had already previously been requested, the callback is immediately invoked
        /// (with the lock released) and error.Cancelled is returned.
        pub fn tryAddCallback(self: *Cancellation.Source, callback: *Cancellation.Callback) error{Cancelled}!void {
            const already_cancelled = blk: {
                const held = self.lock.acquire();
                defer held.release();

                // `cancel` sets the flag before taking the lock, so checking it under the
                // lock guarantees a callback is either linked before `cancel` drains the
                // list, or rejected here.
                if (self.isCancelled()) break :blk true;

                // Prepend callback to the head of the list.
                if (self.head) |old_head| old_head.prev_next = &callback.next;
                callback.next = self.head;
                callback.prev_next = &self.head;
                self.head = callback;
                break :blk false;
            };

            if (already_cancelled) {
                // Invoked outside the lock so the callback may re-enter this source
                // without deadlocking on the non-recursive mutex.
                callback.onCancel(callback);
                return error.Cancelled;
            }
        }

        /// Cleanup a callback that had previously been registered, such that it won't be called when
        /// a request for cancellation happens. If the callback is not registered (never added,
        /// already removed, or already detached by a prior request for cancellation), this method
        /// does nothing, so calling it more than once is safe.
        ///
        /// NOTE(review): if another thread's `cancel` has already detached this callback and is
        /// about to invoke (or is invoking) it, this returns without waiting for that invocation.
        pub fn removeCallback(self: *Cancellation.Source, callback: *Cancellation.Callback) void {
            const held = self.lock.acquire();
            defer held.release();

            const prev_next = callback.prev_next orelse return;
            prev_next.* = callback.next;
            if (callback.next) |next| next.prev_next = prev_next;

            // Reset the links so a repeated removal is a no-op instead of writing stale
            // pointers back into the list.
            callback.next = null;
            callback.prev_next = null;
        }

        /// Request for cancellation. Multiple requests may be made concurrently, though only the first
        /// request will be taken into consideration. The thread that made the request invokes all
        /// registered cancellation callbacks, most recently registered first.
        pub fn cancel(self: *Cancellation.Source) void {
            if (@atomicRmw(bool, &self.cancelled, .Xchg, true, .AcqRel)) return;

            // Callbacks are detached one at a time rather than all up front. A pending
            // callback that gets removed (by an earlier callback or another thread) is
            // therefore never invoked. The detached callback is also never touched after
            // `onCancel` returns, because `onCancel` may end the callback's lifetime.
            while (self.popCallback()) |callback| {
                callback.onCancel(callback);
            }
        }

        /// Unlink and return the callback at the head of the list, or null if it is empty.
        fn popCallback(self: *Cancellation.Source) ?*Cancellation.Callback {
            const held = self.lock.acquire();
            defer held.release();

            const first = self.head orelse return null;
            self.head = first.next;
            if (first.next) |next| next.prev_next = &self.head;
            first.next = null;
            first.prev_next = null;
            return first;
        }
    };

    /// A cancellation token derived from a cancellation source.
    pub const Token = struct {
        source: *Cancellation.Source,

        /// It returns true if a cancellation has been requested.
        pub fn isCancelled(self: Cancellation.Token) bool {
            return self.source.isCancelled();
        }

        /// Attempts to register a callback that will be invoked when a cancellation is requested. If
        /// a cancellation had already previously been requested, the callback is immediately invoked
        /// and error.Cancelled is returned.
        pub fn tryAddCallback(self: Cancellation.Token, callback: *Cancellation.Callback) error{Cancelled}!void {
            return self.source.tryAddCallback(callback);
        }

        /// Cleanup a callback that had previously been registered, such that it won't be called when
        /// a request for cancellation happens. If the callback is not currently registered, this
        /// method does nothing.
        pub fn removeCallback(self: Cancellation.Token, callback: *Cancellation.Callback) void {
            return self.source.removeCallback(callback);
        }
    };
};
// Reference this file's declarations through `refAllDecls` so that `zig test`
// semantically analyzes them, including ones no other test touches.
test {
    std.testing.refAllDecls(@This());
}
test "cancellation: cancel blocking tasks" {
    if (builtin.single_threaded) {
        return error.SkipZigTest;
    }

    const Test = struct {
        token: Cancellation.Token,
        /// Set by the worker thread when one of its checks fails. `Thread.wait()` returns
        /// nothing, so an error returned from a thread's entry point would never reach
        /// this test; the outcome is carried back explicitly instead.
        failed: bool = false,

        /// Thread entry point: run the checks and record (rather than drop) any failure.
        fn run(self: *@This()) void {
            self.check() catch {
                self.failed = true;
            };
        }

        /// Register a callback that wakes this thread, block until cancelled, and verify.
        fn check(self: *@This()) !void {
            var event: std.Thread.StaticResetEvent = .{};
            var callback: struct {
                state: Cancellation.Callback = .{ .onCancel = onCancel },
                ref: *std.Thread.StaticResetEvent,

                pub fn onCancel(callback: *Cancellation.Callback) void {
                    @fieldParentPtr(@This(), "state", callback).ref.set();
                }
            } = .{ .ref = &event };

            // Some threads may race with the request for cancellation in the test, and
            // so we handle errors coming from Cancellation.Source.tryAddCallback().
            if (self.token.tryAddCallback(&callback.state)) {
                defer self.token.removeCallback(&callback.state);
                event.wait();
                try testing.expect(self.token.isCancelled());
            } else |err| {
                try testing.expect(err == error.Cancelled);
            }
        }
    };

    var source: Cancellation.Source = .{};
    var workers = [_]Test{
        .{ .token = source.token() },
        .{ .token = source.token() },
        .{ .token = source.token() },
    };

    const t1 = try std.Thread.spawn(Test.run, &workers[0]);
    const t2 = try std.Thread.spawn(Test.run, &workers[1]);
    const t3 = try std.Thread.spawn(Test.run, &workers[2]);

    source.cancel();

    t1.wait();
    t2.wait();
    t3.wait();

    // Joining the threads makes their `failed` writes visible here.
    for (workers) |worker| {
        try testing.expect(!worker.failed);
    }
}
test "cancellation: cancel non-blocking tasks" {
    const Task = struct {
        /// Register a callback that resumes this frame, suspend, and once resumed
        /// check that the wake-up was caused by a cancellation request.
        fn run(token: Cancellation.Token) !void {
            var waiter: struct {
                state: Cancellation.Callback = .{ .onCancel = wake },
                frame: anyframe = undefined,

                pub fn wake(cb: *Cancellation.Callback) void {
                    const self = @fieldParentPtr(@This(), "state", cb);
                    resume self.frame;
                }
            } = .{};

            // Registration happens before the frame suspends. This relies on the source
            // not being cancelled yet; otherwise `wake` would run immediately and resume
            // the still-undefined `frame`.
            try token.tryAddCallback(&waiter.state);
            defer token.removeCallback(&waiter.state);

            suspend {
                waiter.frame = @frame();
            }

            try testing.expect(token.isCancelled());
        }
    };

    var source: Cancellation.Source = .{};

    var first = async Task.run(source.token());
    var second = async Task.run(source.token());
    var third = async Task.run(source.token());

    // Each task is suspended with its callback registered; cancelling resumes them.
    source.cancel();

    try nosuspend await first;
    try nosuspend await second;
    try nosuspend await third;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment