Created
August 1, 2024 16:36
-
-
Save fruitcoder/d1c3947e46671b8b1ac3492df42e8a3b to your computer and use it in GitHub Desktop.
Trying to make a timeout work for Swift 6
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// A test helper that lets a test suspend until a value produced elsewhere
/// equals an expected value, or until a timeout elapses.
///
/// All continuation bookkeeping happens in synchronous, actor-isolated steps so
/// that a pending continuation is resumed exactly once.
final actor ActorIsolatedExpectation<Value: Sendable & Equatable> {
  /// How a single wait inside `expect` ended.
  private enum Outcome: Sendable {
    /// The current value became equal to the expected value.
    case fulfilled
    /// The deadline passed before the value matched.
    case timedOut
    /// The surrounding task (or the losing child task) was cancelled.
    case cancelled
  }

  /// The actor's expected value
  let expectedValue: Value
  /// The actor's current value (initially `nil`)
  var currentValue: Value?
  let id = UUID()

  /// Continuation of the caller currently suspended in `expect`, if any.
  private var expectationContinuation: UnsafeContinuation<Void, Swift.Error>?
  /// `true` while an `expect` call is in flight; reset when that call returns.
  private var isExpecting = false
  /// Incremented for every wait so that a late cancellation hop from an
  /// earlier wait cannot cancel a newer one.
  private var waiterGeneration = 0

  /// Initializes the actor with an expected value.
  /// - Parameter expectedValue: The value to compare the current value with.
  init(_ expectedValue: Value) {
    self.expectedValue = expectedValue
  }

  /// This function will suspend until the current value matches the expected value or the timeout is reached.
  ///
  /// Calling this function while a previous `expect` is still waiting will fail the test
  /// and return immediately, leaving the previous waiter untouched.
  ///
  /// - Parameters:
  ///   - timeout: Maximum time to wait, in seconds.
  ///   - file: The file reported with a test failure.
  ///   - line: The line reported with a test failure.
  func expect(
    timeout: Double = 1.0,
    file: StaticString = #filePath,
    line: UInt = #line
  ) async throws {
    guard !isExpecting else {
      XCTFail("\(self.id): Calling `expect` again before the previous finished or timed out is a user error.", file: file, line: line)
      // Continuing would overwrite the pending continuation and strand the first caller.
      return
    }

    // Expected value was set before expecting it
    if currentValue == expectedValue {
      return
    }

    isExpecting = true
    defer { isExpecting = false }

    switch await race(timeout: timeout) {
    case .fulfilled, .cancelled:
      // Cancellation is not reported as a failure.
      break
    case .timedOut:
      XCTFail(
        "\(id): Expectation timeout. Current value \(String(describing: currentValue)) wasn't set to \(expectedValue) in time.",
        file: file,
        line: line
      )
    }
  }

  /// Updates the current value.
  ///
  /// If the `newValue` matches the expected value the caller awaiting `expect` will unsuspend.
  /// - Parameter newValue: The new value.
  func setValue(_ newValue: Value) {
    currentValue = newValue
    guard newValue == expectedValue else { return }
    // Resume and clear in one synchronous step so no other path can resume it again.
    expectationContinuation?.resume()
    expectationContinuation = nil
  }

  /// Races "value matched" against "deadline passed" and reports which finished first.
  ///
  /// Kept `nonisolated` so the task group is created and consumed outside the
  /// actor's isolation; only the child tasks hop onto the actor.
  private nonisolated func race(timeout: Double) async -> Outcome {
    let deadline = Date(timeIntervalSinceNow: timeout)
    return await withTaskGroup(of: Outcome.self) { group in
      group.addTask {
        do {
          try await self.waitUntilFulfilled()
          return .fulfilled
        } catch {
          return .cancelled
        }
      }
      group.addTask {
        let interval = deadline.timeIntervalSinceNow
        if interval > 0 {
          // A cancelled sleep is expected when the value arrives first.
          try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
        }
        return Task.isCancelled ? .cancelled : .timedOut
      }
      let first = await group.next() ?? .cancelled
      // Stop the losing child; the group still awaits it before returning.
      group.cancelAll()
      return first
    }
  }

  /// Suspends until `setValue` stores the expected value or the calling task is cancelled.
  /// - Throws: `CancellationError` when the calling task is cancelled first.
  private func waitUntilFulfilled() async throws {
    waiterGeneration += 1
    let generation = waiterGeneration
    try await withTaskCancellationHandler {
      try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation<Void, Swift.Error>) in
        if Task.isCancelled {
          // The cancel handler may already have run and found nothing to resume.
          continuation.resume(throwing: CancellationError())
        } else if self.currentValue == self.expectedValue {
          // Expected value was set while the child task was being scheduled.
          continuation.resume()
        } else {
          self.expectationContinuation = continuation
        }
      }
    } onCancel: {
      // The handler is not actor-isolated, so hop back to touch the stored continuation.
      Task { await self.cancelWaiter(generation: generation) }
    }
  }

  /// Fails the pending continuation with `CancellationError`, unless a newer wait has started.
  private func cancelWaiter(generation: Int) {
    guard generation == waiterGeneration else { return }
    expectationContinuation?.resume(throwing: CancellationError())
    expectationContinuation = nil
  }
}
extension AsyncSequence where Self.Element: Equatable & Sendable {
  /// Returns when the asynchronous sequence first emits the given element.
  ///
  /// Cancellation, timeout and any other error are reported through `XCTFail`
  /// instead of being thrown.
  ///
  /// - Parameters:
  ///   - search: The element to find in the asynchronous sequence.
  ///   - timeout: Maximum time to wait, in seconds.
  ///   - file: The file reported with a test failure.
  ///   - line: The line reported with a test failure.
  func first(
    _ search: Self.Element,
    timeout: TimeInterval = 1.0,
    file: StaticString = #filePath,
    line: UInt = #line
  ) async {
    do {
      try await _until(where: { $0 == search }, timeout: timeout)
    } catch is CancellationError {
      XCTFail(
        "`until(_:timeout:)` isn't cancellable. You should wait for the sequence to produce a \(search) or the operation to time out",
        file: file,
        line: line
      )
    } catch is TimeoutError {
      XCTFail(
        "`until(_:timeout:)` timed out before the sequence returned a \(search).",
        file: file,
        line: line
      )
    } catch {
      XCTFail(
        "`until(_:timeout:)` encountered an unwkown error \(error.localizedDescription)",
        file: file,
        line: line
      )
    }
  }

  /// Returns when the predicate first returns `true` for an emitted element.
  ///
  /// Cancellation, timeout and any other error (including one thrown by
  /// `predicate`) are reported through `XCTFail` instead of being thrown.
  ///
  /// - Parameters:
  ///   - predicate: Called with each emitted element until it returns `true`.
  ///   - timeout: Maximum time to wait, in seconds.
  ///   - file: The file reported with a test failure.
  ///   - line: The line reported with a test failure.
  func until(
    where predicate: @Sendable @escaping (Self.Element) async throws -> Bool,
    timeout: TimeInterval = 1.0,
    file: StaticString = #filePath,
    line: UInt = #line
  ) async {
    do {
      try await _until(where: predicate, timeout: timeout)
    } catch is CancellationError {
      XCTFail(
        "`until` isn't cancellable. You should wait for predicate to return `true` or the operation to time out",
        file: file,
        line: line
      )
    } catch is TimeoutError {
      XCTFail(
        "`until` timed out before the predicate returned `true`.",
        file: file,
        line: line
      )
    } catch {
      XCTFail(
        "`until` encountered an unwkown error \(error.localizedDescription)",
        file: file,
        line: line
      )
    }
  }

  /// Throwing core shared by `first(_:timeout:)` and `until(where:timeout:)`.
  ///
  /// - Parameters:
  ///   - predicate: Called with each emitted element until it returns `true`.
  ///   - timeout: Maximum time to wait, in seconds.
  /// - Throws: Whatever `withTimeout` throws, e.g. `TimeoutError`, or an error
  ///   thrown by the sequence or by `predicate`.
  private func _until(where predicate: @Sendable @escaping (Self.Element) async throws -> Bool, timeout: TimeInterval) async throws {
    // The sequence itself is not required to be Sendable, so it is boxed to
    // cross into the `@Sendable` work closure.
    let sendableSelf = UnsafeTransfer(value: self)
    // Honour the caller's timeout (this used to be hard-coded to 1.0 second).
    return try await withTimeout(timeout) {
      // NOTE(review): if the sequence ends without a match, `first(where:)`
      // yields `nil` and this returns without error - confirm that is intended.
      _ = try await sendableSelf.value.first(where: predicate)
    }
  }
}
/// Wraps any value so it can be captured by a `@Sendable` closure.
///
/// The `@unchecked Sendable` conformance is a promise made by the caller: the
/// compiler performs no checking on the wrapped value.
struct UnsafeTransfer<Value>: @unchecked Sendable {
  /// The wrapped value.
  var value: Value

  /// Wraps `value`.
  init(value: Value) {
    self.value = value
  }
}
/// Runs two operations concurrently and returns the outcome of whichever finishes first.
///
/// The slower operation is cancelled once a winner is known. If the first
/// operation to finish throws, that error is rethrown.
///
/// - Parameters:
///   - f1: The first contender.
///   - f2: The second contender.
/// - Returns: The value produced by the contender that finished first.
/// - Throws: `CancellationError` if the calling task is already cancelled, or
///   the error thrown by the first contender to finish.
func firstOf<R: Sendable>(
  _ f1: @Sendable @escaping () async throws -> R,
  or f2: @Sendable @escaping () async throws -> R
) async throws -> R {
  try Task.checkCancellation()
  return try await withThrowingTaskGroup(of: R.self) { group in
    try Task.checkCancellation()

    let contenders: [@Sendable () async throws -> R] = [f1, f2]
    for contender in contenders {
      let wasAdded = group.addTaskUnlessCancelled { try await contender() }
      if !wasAdded {
        // Stop any contender that was already started before bailing out.
        group.cancelAll()
        throw CancellationError()
      }
    }

    if let winner = try await group.next() {
      group.cancelAll()
      return winner
    }
    // Two tasks were added above, so the group cannot be empty here.
    fatalError()
  }
}
/// Runs `work` and fails with `TimeoutError` if it has not finished within `seconds`.
///
/// - Parameters:
///   - seconds: Maximum time to wait. Non-positive or NaN values time out immediately.
///   - work: The operation to run.
/// - Returns: The value produced by `work`.
/// - Throws: `TimeoutError` when the deadline passes first, `CancellationError`
///   when the timer is cancelled while sleeping, or any error thrown by `work`.
func withTimeout<R: Sendable>(
  _ seconds: Double,
  _ work: @Sendable @escaping () async throws -> R
) async throws -> R {
  try await firstOf {
    try await work()
  } or: {
    if seconds > 0 {
      // Cap the duration so the nanosecond conversion cannot overflow `UInt64`.
      let cappedSeconds = min(seconds, 1_000_000_000)
      // Let a cancelled sleep surface as `CancellationError` rather than
      // being misreported as a timeout.
      try await Task.sleep(nanoseconds: UInt64(cappedSeconds * 1_000_000_000))
    }
    throw TimeoutError()
  }
}
/// Thrown by `withTimeout` when the deadline passes before the work finishes.
struct TimeoutError: Error {
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment