Skip to content

Instantly share code, notes, and snippets.

@brennanMKE
Created November 28, 2024 04:00
Show Gist options
  • Save brennanMKE/1f44fa8bb8a5d08312bd51d4e4b7ef85 to your computer and use it in GitHub Desktop.
Save brennanMKE/1f44fa8bb8a5d08312bd51d4e4b7ef85 to your computer and use it in GitHub Desktop.
Task Queue mimics NSOperationQueue

TaskQueue

This code runs a sequence of tasks either serially or concurrently, up to a maximum number of concurrent tasks. Queued tasks can even be constrained to a global actor such as @MainActor or a custom global actor.

This API mimics NSOperationQueue, which was available before Dispatch (GCD). It supports a maximum number of concurrent tasks and other features, such as barrier tasks. Normally tasks are allowed to run concurrently. A barrier task runs serially: every task that has already started is allowed to complete, then the barrier task runs alone, and only after it finishes do later tasks run as before. This is useful when several tasks do work in parallel and a barrier task runs after them to gather their outputs before another set of tasks does further processing.

import Foundation
/// A unit of work accepted by `TaskQueue`: an async, `Sendable` closure producing no value.
typealias TaskItem = @Sendable () async -> ()

/// A queued work item paired with its scheduling mode.
struct TaskContext {
    /// When `true`, `TaskQueue` treats this item as a barrier and runs it on its own.
    let isBarrier: Bool
    /// The work to perform.
    let task: TaskItem

    init(isBarrier: Bool, task: @escaping TaskItem) {
        self.isBarrier = isBarrier
        self.task = task
    }
}
/// Runs submitted work items either one at a time or with a bounded number in flight,
/// in the spirit of `NSOperationQueue`.
///
/// Items are consumed in submission order by a single processing task started in `init`.
/// A non-barrier item is launched in its own `Task` once a slot is free, so completion order
/// may differ from submission order when more than one item may run at once. A barrier item
/// waits until every previously started item has finished, then runs alone and is awaited
/// before the next item is considered.
final class TaskQueue: Sendable {
    /// Counts running items and suspends the processing loop when it must wait: either for a
    /// free slot, or (before a barrier) for all running items to finish.
    ///
    /// Each continuation slot holds at most one waiter. That is sufficient because the
    /// processing loop in `processTasks()` is the only caller of `waitIfNecessary`.
    actor TaskCounter {
        private var runningTasks = 0
        private var atLimitContinuation: CheckedContinuation<Void, Never>?
        private var barrierContinuation: CheckedContinuation<Void, Never>?
        private let maxConcurrentTasks: Int

        /// - Parameter maxConcurrentTasks: Maximum number of items allowed to run at once.
        ///   Values below 1 are treated as 1. With a limit of 0 nothing could ever start, so
        ///   `waitIfNecessary` would suspend forever and the queue would hang.
        init(maxConcurrentTasks: Int) {
            self.maxConcurrentTasks = max(1, maxConcurrentTasks)
        }

        /// Records that an item has started.
        func increment() {
            runningTasks += 1
        }

        /// Records that an item has finished. Resumes the processing loop if it was waiting
        /// for running items to drain before a barrier, or for a free slot.
        func decrement() {
            runningTasks -= 1
            if let barrierContinuation, runningTasks == 0 {
                barrierContinuation.resume()
                self.barrierContinuation = nil
                return
            }
            guard let atLimitContinuation, runningTasks < maxConcurrentTasks else {
                return
            }
            atLimitContinuation.resume()
            self.atLimitContinuation = nil
        }

        /// Suspends until the next item may start.
        ///
        /// - Parameter isBarrier: When `true`, waits until no items are running.
        ///   Otherwise waits only while the concurrency limit is reached.
        func waitIfNecessary(isBarrier: Bool = false) async {
            if isBarrier, runningTasks > 0 {
                await withCheckedContinuation { continuation in
                    self.barrierContinuation = continuation
                }
                return
            }
            guard runningTasks >= maxConcurrentTasks else {
                return
            }
            await withCheckedContinuation { continuation in
                self.atLimitContinuation = continuation
            }
        }
    }

    /// How many items the queue may run at once.
    enum Behavior {
        /// One item at a time, started in submission order.
        case serial
        /// Up to `maxConcurrentTasks` items at once (values below 1 behave like 1).
        case concurrent(maxConcurrentTasks: Int)

        /// The concurrency limit implied by this behavior.
        var maxConcurrentTasks: Int {
            switch self {
            case .serial:
                1
            case .concurrent(let maxConcurrentTasks):
                maxConcurrentTasks
            }
        }
    }

    private let taskContexts: AsyncStream<TaskContext>
    private let dataContinuation: AsyncStream<TaskContext>.Continuation
    private let counter: TaskCounter

    /// Creates the queue and immediately starts processing submitted items.
    ///
    /// - Parameters:
    ///   - behavior: Serial or bounded-concurrent execution. Defaults to `.serial`.
    ///   - bufferingPolicy: Policy for items submitted but not yet picked up by the
    ///     processing loop. Bounded policies may drop items; see `receiveNext`.
    init(behavior: Behavior = .serial, bufferingPolicy: AsyncStream<TaskContext>.Continuation.BufferingPolicy = .unbounded) {
        self.counter = TaskCounter(maxConcurrentTasks: behavior.maxConcurrentTasks)
        (taskContexts, dataContinuation) = AsyncStream.makeStream(bufferingPolicy: bufferingPolicy)
        processTasks()
    }

    deinit {
        // The processing task does not retain `self` (see `processTasks`). Finishing the
        // stream here lets that task deliver what is already buffered and then exit,
        // instead of waiting on the stream forever.
        dataContinuation.finish()
    }

    /// Submits `task` to the queue.
    ///
    /// - Parameters:
    ///   - isBarrier: Run `task` alone, after every previously started item has finished.
    ///     Later items do not start until it returns.
    ///   - task: The work to run.
    /// - Returns: The stream's yield result. `.dropped(_:)` carries an item that was
    ///   discarded because a bounded buffer was full. With `.bufferingNewest` that is an
    ///   older pending item, not `task`. `.terminated` means `finish()` was already called,
    ///   and `task` will never run.
    @discardableResult
    func receiveNext(isBarrier: Bool = false, task: @escaping TaskItem) -> AsyncStream<TaskContext>.Continuation.YieldResult {
        dataContinuation.yield(TaskContext(isBarrier: isBarrier, task: task))
    }

    /// Stops accepting new items. Items already buffered are still delivered to the
    /// processing loop; later `receiveNext` calls return `.terminated`.
    func finish() {
        dataContinuation.finish()
    }

    /// Starts the loop that pulls items from the stream and runs them.
    ///
    /// NOTE: called once from `init`. Calling it again would start a second concurrent
    /// consumer of the same `AsyncStream`, which the stream does not support.
    func processTasks() {
        // Capture only what the loop needs. Capturing `self` would keep the queue alive
        // for as long as the stream is open, so it could never be deinitialized.
        Task { [taskContexts, counter] in
            for await taskContext in taskContexts {
                await counter.waitIfNecessary(isBarrier: taskContext.isBarrier)
                await counter.increment()
                if taskContext.isBarrier {
                    // Run serially: the loop does not advance until the barrier returns.
                    await taskContext.task()
                    await counter.decrement()
                } else {
                    // Run concurrently; the slot is released when the item completes.
                    Task {
                        await taskContext.task()
                        await counter.decrement()
                    }
                }
            }
        }
    }
}
import Foundation
import Testing
@testable import TaskQueueKit
/// Exercises `TaskQueue` in serial, concurrent, global-actor-isolated and barrier configurations.
///
/// Every work item sleeps for a random duration between `delay / 2` and `delay` milliseconds
/// before appending its number to a shared `Buffer`. Completion order is therefore scrambled
/// unless the queue itself enforces ordering.
struct TasksTests {
    /// Upper bound, in milliseconds, of each work item's simulated duration.
    let delay = 1000

    /// Custom global actor used to check that queued closures may be isolated to it.
    @globalActor public final actor TasksTestsActor {
        public static let shared = TasksTestsActor()
    }

    /// Records appended numbers and lets one caller suspend until `max` numbers have arrived.
    actor Buffer {
        private var continuation: CheckedContinuation<Void, Never>?
        let max: Int
        var numbers: [Int] = []

        init(max: Int) {
            self.max = max
        }

        /// Appends `number`, resuming the waiter once the buffer holds exactly `max` numbers.
        func append(number: Int) {
            numbers.append(number)
            if let continuation, numbers.count == max {
                continuation.resume()
                self.continuation = nil
            }
        }

        /// Returns once `max` numbers have been appended. Supports a single waiter at a time.
        func untilDone() async {
            guard numbers.count < max else {
                return
            }
            await withCheckedContinuation { continuation in
                self.continuation = continuation
            }
        }
    }

    /// Concurrency limit for the "many" test: half of the active cores, but at least 2.
    var halfCoreCount: Int {
        max(ProcessInfo.processInfo.activeProcessorCount / 2, 2)
    }

    /// Prints whether `numbers` is a run of consecutive ascending integers.
    ///
    /// Diagnostic only: concurrent runs are expected to finish out of order,
    /// so those tests assert on the sorted output instead.
    func validateOrder(numbers: [Int]) {
        var count = 0
        for (current, next) in zip(numbers, numbers.dropFirst()) where current != next - 1 {
            count += 1
        }
        if count > 0 {
            print("💥 There are \(count) numbers out of order")
        } else {
            print("👌 Numbers are in order")
        }
    }

    @Test func testTaskQueueSerial() async throws {
        print("There are \(ProcessInfo.processInfo.activeProcessorCount) cores")
        let max = ProcessInfo.processInfo.activeProcessorCount * 2
        let numbers = Array(1...max)
        let taskQueue = TaskQueue(behavior: .serial, bufferingPolicy: .bufferingNewest(max))
        let buffer = Buffer(max: max)
        for number in numbers {
            taskQueue.receiveNext {
                let randomizedDelay = Int.random(in: (delay / 2)...delay)
                try? await Task.sleep(for: .milliseconds(randomizedDelay))
                print(number, terminator: " ")
                await buffer.append(number: number)
            }
        }
        await buffer.untilDone()
        print("")
        print("Input:", numbers)
        let output = await buffer.numbers
        validateOrder(numbers: output)
        // Serial execution must preserve submission order exactly.
        #expect(numbers == output)
        print("Output:", output)
    }

    @Test func testTaskQueueConcurrentTwo() async throws {
        print("There are \(ProcessInfo.processInfo.activeProcessorCount) cores")
        let max = ProcessInfo.processInfo.activeProcessorCount * 2
        let numbers = Array(1...max)
        let maxConcurrentTasks = 2
        let taskQueue = TaskQueue(behavior: .concurrent(maxConcurrentTasks: maxConcurrentTasks), bufferingPolicy: .bufferingNewest(max))
        let buffer = Buffer(max: max)
        for number in numbers {
            taskQueue.receiveNext {
                let randomizedDelay = Int.random(in: (delay / 2)...delay)
                try? await Task.sleep(for: .milliseconds(randomizedDelay))
                print(number, terminator: " ")
                await buffer.append(number: number)
            }
        }
        await buffer.untilDone()
        print("")
        print("Input:", numbers)
        let output = await buffer.numbers
        validateOrder(numbers: output)
        #expect(numbers == output.sorted())
        print("Output:", output)
    }

    @Test func testTaskQueueConcurrentMany() async throws {
        print("There are \(ProcessInfo.processInfo.activeProcessorCount) cores")
        let max = ProcessInfo.processInfo.activeProcessorCount * 2
        let numbers = Array(1...max)
        let taskQueue = TaskQueue(behavior: .concurrent(maxConcurrentTasks: halfCoreCount), bufferingPolicy: .bufferingNewest(max))
        let buffer = Buffer(max: max)
        for number in numbers {
            taskQueue.receiveNext {
                let randomizedDelay = Int.random(in: (delay / 2)...delay)
                try? await Task.sleep(for: .milliseconds(randomizedDelay))
                print(number, terminator: " ")
                await buffer.append(number: number)
            }
        }
        await buffer.untilDone()
        print("")
        print("Input:", numbers)
        let output = await buffer.numbers
        validateOrder(numbers: output)
        #expect(numbers == output.sorted())
        print("Output:", output)
    }

    @Test func testTaskQueueConcurrentManyOnMainActor() async throws {
        print("There are \(ProcessInfo.processInfo.activeProcessorCount) cores")
        let max = ProcessInfo.processInfo.activeProcessorCount * 2
        let numbers = Array(1...max)
        let taskQueue = TaskQueue(behavior: .concurrent(maxConcurrentTasks: 3), bufferingPolicy: .bufferingNewest(max))
        let buffer = Buffer(max: max)
        for number in numbers {
            taskQueue.receiveNext { @MainActor in
                let randomizedDelay = Int.random(in: (delay / 2)...delay)
                try? await Task.sleep(for: .milliseconds(randomizedDelay))
                print(number, terminator: " ")
                await buffer.append(number: number)
            }
        }
        await buffer.untilDone()
        print("")
        print("Input:", numbers)
        let output = await buffer.numbers
        validateOrder(numbers: output)
        #expect(numbers == output.sorted())
        print("Output:", output)
    }

    @Test func testTaskQueueConcurrentManyOnTasksTestsActor() async throws {
        print("There are \(ProcessInfo.processInfo.activeProcessorCount) cores")
        let max = ProcessInfo.processInfo.activeProcessorCount * 2
        let numbers = Array(1...max)
        let taskQueue = TaskQueue(behavior: .concurrent(maxConcurrentTasks: 3), bufferingPolicy: .bufferingNewest(max))
        let buffer = Buffer(max: max)
        for number in numbers {
            taskQueue.receiveNext { @TasksTestsActor in
                let randomizedDelay = Int.random(in: (delay / 2)...delay)
                try? await Task.sleep(for: .milliseconds(randomizedDelay))
                print(number, terminator: " ")
                await buffer.append(number: number)
            }
        }
        await buffer.untilDone()
        print("")
        print("Input:", numbers)
        let output = await buffer.numbers
        validateOrder(numbers: output)
        #expect(numbers == output.sorted())
        print("Output:", output)
    }

    @Test func testTaskQueueConcurrentManyOnTasksTestsActorForOdd() async throws {
        print("There are \(ProcessInfo.processInfo.activeProcessorCount) cores")
        let max = ProcessInfo.processInfo.activeProcessorCount * 2
        let numbers = Array(1...max)
        let taskQueue = TaskQueue(behavior: .concurrent(maxConcurrentTasks: 3), bufferingPolicy: .bufferingNewest(max))
        let buffer = Buffer(max: max)
        for number in numbers {
            if number % 2 == 0 {
                taskQueue.receiveNext {
                    let randomizedDelay = Int.random(in: (delay / 2)...delay)
                    try? await Task.sleep(for: .milliseconds(randomizedDelay))
                    print(number, terminator: " ")
                    await buffer.append(number: number)
                }
            } else {
                // Odd numbers run isolated to the custom global actor; even numbers are nonisolated.
                taskQueue.receiveNext { @TasksTestsActor in
                    let randomizedDelay = Int.random(in: (delay / 2)...delay)
                    try? await Task.sleep(for: .milliseconds(randomizedDelay))
                    print(number, terminator: " ")
                    await buffer.append(number: number)
                }
            }
        }
        await buffer.untilDone()
        print("")
        print("Input:", numbers)
        let output = await buffer.numbers
        validateOrder(numbers: output)
        #expect(numbers == output.sorted())
        print("Output:", output)
    }

    @Test func testTaskQueueWithBarriers() async throws {
        print("There are \(ProcessInfo.processInfo.activeProcessorCount) cores")
        let max = ProcessInfo.processInfo.activeProcessorCount * 2
        let numbers = Array(1...max)
        let taskQueue = TaskQueue(behavior: .concurrent(maxConcurrentTasks: ProcessInfo.processInfo.activeProcessorCount), bufferingPolicy: .bufferingNewest(max))
        let buffer = Buffer(max: max)
        print("Queueing", terminator: " ")
        for number in numbers {
            let isBarrier = number % 7 == 0
            print(number, terminator: " ")
            taskQueue.receiveNext(isBarrier: isBarrier) {
                let randomizedDelay = Int.random(in: (delay / 2)...delay)
                try? await Task.sleep(for: .milliseconds(randomizedDelay))
                if isBarrier {
                    print("[\(number)]", terminator: " ")
                } else {
                    print(number, terminator: " ")
                }
                await buffer.append(number: number)
            }
        }
        print("")
        await buffer.untilDone()
        print("")
        print("Input:", numbers)
        let output = await buffer.numbers
        validateOrder(numbers: output)
        #expect(numbers == output.sorted())
        print("Output:", output)
        // A barrier starts only after every earlier item has finished, and the queue awaits it
        // before starting any later item. So barrier `n` must be the n-th number appended.
        // `output` holds at least `max` numbers after `untilDone()`, so the index is in range.
        for barrier in numbers where barrier % 7 == 0 {
            #expect(output[barrier - 1] == barrier)
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment