Skip to content

Instantly share code, notes, and snippets.

@drewmccormack
Created July 26, 2024 11:49
Show Gist options
  • Save drewmccormack/b05b2f6ce9a976ae7f6f349a64c9f071 to your computer and use it in GitHub Desktop.
Save drewmccormack/b05b2f6ce9a976ae7f6f349a64c9f071 to your computer and use it in GitHub Desktop.
A macro that can be applied to an async func to generate a peer func (serial_<funcname>) that uses an internal queue to serialize execution. That is, it guarantees that only one copy of the func will run at a time. (Note that this can deadlock if it reenters.)
import Serial
import Foundation
// Example: Deliberately introducing shared data to show an interleaving race
// when @Serial is not used.
// `numbersToTen()` resets and appends to this global on every call, so two
// overlapping calls write into the same array.
// NOTE(review): `nonisolated(unsafe)` presumably exists only to silence the
// compiler's concurrency checking for this demo global — confirm before reuse.
nonisolated(unsafe) var shared: [Int] = []
/// Should return the numbers 1 to 10 if working properly.
///
/// Resets the global `shared` array, then appends 1 through 10 to it with a short
/// random pause after each append. Because the pause is a suspension point, two
/// overlapping calls can interleave their appends into the same array.
///
/// - Returns: The contents of `shared` once this call has appended all ten values.
/// - Throws: `CancellationError` if the calling task is cancelled during a pause.
@Serial
func numbersToTen() async throws -> [Int] {
    shared = []
    for i in 1...10 {
        shared.append(i)
        // Suspend for 10-20 ms so concurrent callers get a chance to interleave.
        // The function is already `throws`, so cancellation is propagated to the
        // caller rather than silently discarded with `try?`.
        try await Task.sleep(for: .seconds(TimeInterval.random(in: 0.01...0.02)))
    }
    return shared
}
// Two concurrent calls to the plain function: both write into `shared`.
print("INTERLEAVING")
async let firstRaw: [Int] = try await numbersToTen()
async let secondRaw: [Int] = try await numbersToTen()
let interleavedResults = try await [firstRaw, secondRaw]
interleavedResults.forEach { print($0) }

print()

// Two concurrent calls to the macro-generated peer function.
print("SERIALIZED")
async let firstQueued: [Int] = try await serial_numbersToTen()
async let secondQueued: [Int] = try await serial_numbersToTen()
let serializedResults = try await [firstQueued, secondQueued]
serializedResults.forEach { print($0) }
// Sample Output
//
// INTERLEAVING
// [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 8, 7, 9, 8, 10, 9]
// [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 8, 7, 9, 8, 10, 9, 10]
//
// SERIALIZED
// [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
// [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
//
import SwiftSyntax
import SwiftSyntaxMacros
import SwiftCompilerPlugin
/// Errors thrown while expanding the `@Serial` macro.
public enum SerialError: Error {
    /// The declaration the macro was attached to cannot be expanded.
    case invalidInput
    /// The macro produced an unusable expansion.
    case invalidOutput
}

// MARK: - CustomStringConvertible

extension SerialError: CustomStringConvertible {
    /// Human-readable message for the error case.
    public var description: String {
        switch self {
        case .invalidInput: return "Invalid input for macro expansion"
        case .invalidOutput: return "Invalid output from macro expansion"
        }
    }
}
/// Peer macro that, for an `async` function `foo`, generates `serial_foo` with the
/// same signature.
///
/// The generated function wraps the original body in a closure and hands it to a
/// function-local `AsyncSerialQueue`. That queue is stored in a `static let`, so every
/// call of the generated function shares it. A single `Task` drains the queue's
/// `AsyncStream` and awaits each closure before taking the next, so only one copy of
/// the body runs at a time. If the body calls its own `serial_` peer, the inner call
/// waits on the queue while the outer call still occupies it, and never completes.
///
/// Limitations visible in the expansion: declaration modifiers, attributes and generic
/// parameters of the original function are not copied onto the peer.
public struct SerialMacro: PeerMacro {
    /// Builds the `serial_<name>` peer declaration.
    ///
    /// - Parameters:
    ///   - node: The `@Serial` attribute (not inspected).
    ///   - declaration: The declaration the attribute is attached to.
    ///   - context: The macro expansion context (not inspected).
    /// - Returns: An array holding the single generated function declaration.
    /// - Throws: `SerialError.invalidInput` if `declaration` is not a function, is not
    ///   `async`, or has no body.
    public static func expansion<Context: MacroExpansionContext, Declaration: DeclSyntaxProtocol>(
        of node: AttributeSyntax,
        providingPeersOf declaration: Declaration,
        in context: Context
    ) throws -> [DeclSyntax] {
        guard let funcDecl = declaration.as(FunctionDeclSyntax.self) else {
            throw SerialError.invalidInput
        }
        // The peer reuses the original signature and uses `await` in its body, so a
        // non-async original would expand into code that cannot compile.
        guard funcDecl.signature.effectSpecifiers?.asyncSpecifier != nil else {
            throw SerialError.invalidInput
        }
        // Without a body (e.g. a protocol requirement) there is nothing to enqueue.
        guard let body = funcDecl.body?.description else {
            throw SerialError.invalidInput
        }

        let originalName = funcDecl.name.text
        let signature = funcDecl.signature.description

        // Effects: the queue plumbing mirrors whether the original can throw.
        let effectSpecifiers = funcDecl.signature.effectSpecifiers?.description ?? ""
        let canThrow = effectSpecifiers.contains("throws")
        let tryKeyword = canThrow ? "try" : ""
        let throwsKeyword = canThrow ? "throws" : ""
        let errorType = canThrow ? "Swift.Error" : "Never"
        let continuationFuncName = canThrow ? "withCheckedThrowingContinuation" : "withCheckedContinuation"

        // Return type: `trimmedDescription` drops surrounding trivia (whitespace and
        // comments), so the text can be spliced directly in front of `{`.
        let trimmedReturnClause = funcDecl.signature.returnClause?.trimmedDescription
        let returnsVoid = trimmedReturnClause == nil
        let returnClause = trimmedReturnClause ?? ""
        let returnClauseOrVoid = trimmedReturnClause ?? "-> Void"
        let returnType = funcDecl.signature.returnClause?.type.trimmedDescription

        // Code the consumer task runs per queued item: await the block, then resume
        // the waiting caller with the result (if any).
        let voidReturnBlockCall =
            """
            \(tryKeyword) await item.block()
            item.continuation.resume()
            """
        let nonvoidReturnBlockCall =
            """
            let returned = \(tryKeyword) await item.block()
            item.continuation.resume(returning: returned)
            """
        let runBlock = returnsVoid ? voidReturnBlockCall : nonvoidReturnBlockCall

        // Throwing blocks forward their error to the waiting caller.
        let blockCall: String
        if canThrow {
            blockCall =
                """
                do {
                \(runBlock)
                } catch {
                item.continuation.resume(throwing: error)
                }
                """
        } else {
            blockCall = runBlock
        }

        let newFuncSyntax: DeclSyntax =
            """
            func serial_\(raw: originalName)\(raw: signature) {
            return \(raw: tryKeyword) await AsyncSerialQueue.serialized(
            \(raw: body)
            )
            class AsyncSerialQueue {
            private typealias ItemBlock = () async \(raw: throwsKeyword) \(raw: returnClauseOrVoid)
            private struct Item {
            let block: ItemBlock
            let continuation: CheckedContinuation<\(raw: returnType ?? "Void"), \(raw: errorType)>
            }
            private var stream: AsyncStream<Item>
            private var streamContinuation: AsyncStream<Item>.Continuation
            init() {
            var cont: AsyncStream<Item>.Continuation? = nil
            let stream: AsyncStream<Item> = AsyncStream { continuation in
            cont = continuation
            }
            self.stream = stream
            self.streamContinuation = cont!
            Task {
            for await item in stream {
            \(raw: blockCall)
            }
            }
            }
            func enqueue(_ block: @escaping ItemBlock) async \(raw: throwsKeyword) \(raw: returnClause) {
            \(raw: tryKeyword) await \(raw: continuationFuncName) { continuation in
            let item = Item(block: block, continuation: continuation)
            streamContinuation.yield(item)
            }
            }
            static let serializedQueue = AsyncSerialQueue()
            /// Passing a function to this will execute it in the queue atomically, and in order.
            static func serialized(_ wrapped: @escaping ItemBlock) async \(raw: throwsKeyword) \(raw: returnClause) {
            \(raw: tryKeyword) await serializedQueue.enqueue(wrapped)
            }
            }
            }
            """
        return [newFuncSyntax]
    }
}
/// Compiler-plugin entry point that exposes `SerialMacro` to the compiler.
@main
struct SerialPlugin: CompilerPlugin {
    /// The macro implementations provided by this plugin.
    let providingMacros: [Macro.Type] = [SerialMacro.self]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment