Skip to content

Instantly share code, notes, and snippets.

@thomsmed
Created August 18, 2024 14:56
Show Gist options
  • Save thomsmed/7d869d8ee5bd762326bf27c888d85abc to your computer and use it in GitHub Desktop.
Save thomsmed/7d869d8ee5bd762326bf27c888d85abc to your computer and use it in GitHub Desktop.
An AsyncSemaphore suited for use in asynchronous contexts. Heavily inspired by https://github.com/groue/Semaphore.
//
// AsyncSemaphore.swift
//
import Foundation
/// An AsyncSemaphore suited for use in asynchronous contexts. Heavily inspired by https://github.com/groue/Semaphore.
///
/// `count` tracks the number of free protected resources. Waiting decrements it and signalling
/// increments it, so it goes negative while tasks are queued waiting for a signal.
///
/// All mutable state (`count`, `queuedContinuations`, and each queued waiter's `state`) is only
/// touched while `recursiveLock` is held, which is what the `@unchecked Sendable` conformance relies on.
public final class AsyncSemaphore: @unchecked Sendable {
    /// A waiter that is about to suspend, or is suspended, on the semaphore.
    /// Its `state` is only read or written while the owning semaphore's lock is held.
    final class QueuedContinuation: @unchecked Sendable {
        enum State {
            /// Created, but its continuation has not been registered yet.
            case pending
            /// Registered; holds the continuation to resume once a resource is available.
            case waiting(CheckedContinuation<Void, any Error>)
            /// The calling task was cancelled before its continuation could be registered.
            case canceled
        }

        var state: State

        init(state: State) {
            self.state = state
        }
    }

    /// Guards all mutable state. It is recursive because a task cancellation handler may run
    /// synchronously on the waiting thread while that thread still holds the lock.
    private let recursiveLock: NSRecursiveLock = NSRecursiveLock()

    /// Suspended waiters. New waiters are inserted at index 0 and `signal()` resumes from the end,
    /// giving first-in, first-out order.
    private var queuedContinuations: [QueuedContinuation] = []

    /// Number of free protected resources (negative while tasks are waiting).
    private var count: Int

    /// Creates a semaphore guarding `count` protected resources.
    /// - Parameter count: Initial number of free resources; must be zero or more.
    public init(count: Int) {
        precondition(count >= 0, "The number of protected resources should start off as zero or more")
        self.count = count
    }

    deinit {
        // Only suspended waiters make deallocation unsafe (their continuations would never be resumed).
        // A count of zero merely means every resource is in use and is a valid state to deallocate in,
        // e.g. right after `init(count: 0)`.
        precondition(queuedContinuations.isEmpty, "There are still tasks waiting for a protected resource")
    }

    private func lock() {
        // The compiler does not allow calls to NSRecursiveLock.lock() in asynchronous contexts.
        // We circumvent it by calling this wrapper method instead.
        // This is dangerous, but we know what we're doing.
        recursiveLock.lock()
    }

    private func unlock() {
        // The compiler does not allow calls to NSRecursiveLock.unlock() in asynchronous contexts.
        // We circumvent it by calling this wrapper method instead.
        // This is dangerous, but we know what we're doing.
        recursiveLock.unlock()
    }
}
// MARK: Public Methods
public extension AsyncSemaphore {
    /// Waits for, and claims, one protected resource.
    ///
    /// Returns immediately if a resource is free; otherwise suspends until a matching `signal()`.
    /// This variant does not react to task cancellation and never throws.
    func wait() async {
        lock()
        count -= 1
        if count >= 0 {
            // We got a resource. Unlock and return.
            unlock()
            return
        }
        do {
            // A throwing continuation is used only so this waiter fits the shared queue's element type.
            // No cancellation handler is installed for it, and `signal()` resumes without an error.
            try await withCheckedThrowingContinuation { continuation in
                // The body runs synchronously before suspension, so the lock taken above is released
                // only once this waiter is visible to `signal()`.
                defer { unlock() }
                let queuedContinuation = QueuedContinuation(state: .waiting(continuation))
                queuedContinuations.insert(queuedContinuation, at: 0)
            }
        } catch {
            assertionFailure("This should never happen")
        }
    }

    /// Waits for, and claims, one protected resource unless the calling task is cancelled.
    ///
    /// - Throws: `CancellationError` if the task is cancelled before a resource is obtained.
    ///   In that case no resource is claimed, so the caller must not call `signal()` for it.
    func waitUnlessCancelled() async throws {
        try Task.checkCancellation()
        lock()
        count -= 1
        if count >= 0 {
            // We got a resource. Unlock and return.
            unlock()
            return
        }
        let queuedContinuation = QueuedContinuation(state: .pending)
        try await withTaskCancellationHandler {
            try await withCheckedThrowingContinuation { continuation in
                defer { unlock() }
                switch queuedContinuation.state {
                case .pending:
                    queuedContinuation.state = .waiting(continuation)
                    queuedContinuations.insert(queuedContinuation, at: 0)
                case .waiting:
                    assertionFailure("Unexpected QueuedContinuation.State in this context")
                case .canceled:
                    // The calling task has been cancelled before we were able to put it in queue for the protected resource(s).
                    continuation.resume(throwing: CancellationError())
                }
            }
        } onCancel: {
            // onCancel might be called right away (on this thread, with the lock already held)
            // if the calling Task has already been marked as cancelled.
            lock()
            defer { unlock() }
            switch queuedContinuation.state {
            case .pending:
                // Cancelled before the continuation was registered: give the reserved slot back.
                // The operation above then observes `.canceled` and throws.
                count += 1
                queuedContinuation.state = .canceled
            case .waiting(let continuation):
                // Only a waiter that is still queued may be cancelled. If `signal()` already dequeued
                // and resumed it, the task now owns a resource. Resuming again would trap (double resume
                // of a CheckedContinuation), and bumping `count` would hand out a phantom resource.
                guard let index = queuedContinuations.firstIndex(where: { $0 === queuedContinuation }) else {
                    return
                }
                queuedContinuations.remove(at: index)
                count += 1
                continuation.resume(throwing: CancellationError())
            case .canceled:
                assertionFailure("Unexpected QueuedContinuation.State in this context")
            }
        }
    }

    /// Releases one protected resource, resuming the longest-waiting queued task if there is one.
    func signal() {
        lock()
        defer { unlock() }
        count += 1
        switch queuedContinuations.popLast()?.state {
        case .pending:
            assertionFailure("Unexpected QueuedContinuation.State in this context")
        case .waiting(let continuation):
            continuation.resume()
        case .canceled:
            assertionFailure("Unexpected QueuedContinuation.State in this context")
        case nil:
            // No queued continuations.
            break
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment