Created
June 12, 2026 16:08
-
-
Save fbeeper/4ebdd6b5b2eaa2d5cb1b74e36e01a941 to your computer and use it in GitHub Desktop.
Quick-and-dirty draft of a gated tool and its use with DynamicProfile (an AgentKitten spin-off)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import Foundation | |
| import FoundationModels | |
/// A chat agent whose language-model session can only run tools after explicit approval.
///
/// The agent owns a ``ToolApprovalGate``. Whoever receives `onApprovalRequired`
/// resolves pending calls through `toolGate` (approve / deny).
class ChatAgent {
    /// Gate that every gated tool in `session` waits on before executing.
    let toolGate: ToolApprovalGate
    /// Session configured with ``ChatAgentProfile`` and wired to `toolGate`.
    let session: LanguageModelSession

    /// Creates the agent and its session.
    /// - Parameters:
    ///   - model: The language model that drives the session.
    ///   - onApprovalRequired: Called for each gated tool call that needs a decision;
    ///     resolve it via `toolGate.approve(callID:)` or `toolGate.deny(callID:reason:)`.
    public init(
        model: some LanguageModel,
        onApprovalRequired: @escaping (PendingToolCall) async -> Void
    ) {
        let gate = ToolApprovalGate(onApprovalRequired: onApprovalRequired)
        self.toolGate = gate
        self.session = LanguageModelSession(
            profile: ChatAgentProfile(model: model, toolGate: gate)
        )
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Session profile used by ``ChatAgent``.
///
/// It contains short-answer instructions plus a `RandomNumberTool` wrapped in a
/// ``GatedTool``, so the tool only runs after `toolGate` approves the call.
struct ChatAgentProfile<LM: LanguageModel>: LanguageModelSession.DynamicProfile {
    /// Model the profile is bound to via `.model(_:)`.
    var model: LM
    /// Gate shared with the owning agent; gated tools wait on it for a decision.
    let toolGate: ToolApprovalGate

    var body: some DynamicProfile {
        Profile {
            Instructions("You are a helpful assistant that gives extremely short answers.")
            Instructions("If a tool call is denied inform the user clearly where it failed but still report on any partial results you have that are meaningful.")
            // When run, the gated tool registers itself with `toolGate` (which fires
            // `onApprovalRequired`) and suspends until the call is approved or denied.
            GatedTool(RandomNumberTool(), gate: toolGate)
        }
        .model(model)
        // The approval check deliberately lives inside the tool rather than in an
        // `.onToolCall { call in try await ... }` hook: an error thrown from that hook
        // ends generation, whereas a denied GatedTool returns an error payload the
        // model can still report on.
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import Foundation | |
| import FoundationModels | |
/// Wraps a tool so that each invocation must be approved through a ``ToolApprovalGate``.
///
/// Name, description and parameter schema are forwarded from the wrapped tool, so
/// the model sees the same tool. On each call the wrapper:
/// 1. registers a ``PendingToolCall`` with `gate`, which notifies its approval handler;
/// 2. suspends until the gate reports a decision;
/// 3. on approval, runs the wrapped tool and returns ``GatedOutput/success(_:)``;
///    on denial, returns ``GatedOutput/failure(error:)`` without running it.
///
/// A denial is returned as data rather than thrown, so generation continues.
/// Errors thrown by the wrapped tool propagate unchanged. If the calling task is
/// cancelled, the pending call is cancelled on the gate and `CancellationError`
/// is thrown.
public struct GatedTool<T: Tool>: Tool {
    public typealias Arguments = T.Arguments
    public typealias Output = GatedOutput<T.Output>

    /// The wrapped tool. `Tool` is `Sendable`, so it can be stored directly instead of
    /// capturing its members in `@Sendable` closures.
    private let base: T
    let gate: ToolApprovalGate

    /// Creates a gated wrapper around `base` whose calls wait on `gate` for approval.
    public init(_ base: T, gate: ToolApprovalGate) {
        self.base = base
        self.gate = gate
    }

    public var name: String {
        base.name
    }

    public var description: String {
        base.description
    }

    public var parameters: GenerationSchema {
        base.parameters
    }

    public func call(arguments: Arguments) async throws -> Output {
        let toolCallID = UUID().uuidString
        let pendingToolCall = PendingToolCall(id: toolCallID, name: name)
        try await gate.register(call: pendingToolCall)

        return try await withTaskCancellationHandler(
            operation: {
                let approvalDecision = try await gate.waitForResolution(callID: toolCallID)
                // If the task was cancelled while waiting, the gate resolved the call as
                // denied; surface that as cancellation rather than as a model-visible denial.
                try Task.checkCancellation()
                switch approvalDecision {
                case .approved:
                    let value = try await base.call(arguments: arguments)
                    return .success(value)
                case .denied(let reason):
                    return .failure(error: reason)
                }
            },
            onCancel: {
                // onCancel is synchronous, so hop onto the gate actor in a new task.
                Task {
                    await gate.cancel(callID: toolCallID)
                }
            }
        )
    }
}
/// The value a ``GatedTool`` returns to the model.
///
/// - `success` carries the wrapped tool's output and renders exactly as that output would.
/// - `failure` carries a denial reason and renders as a small JSON object,
///   `{ "error" : "<reason>" }`.
public enum GatedOutput<SuccessOutput: PromptRepresentable>: PromptRepresentable {
    case success(SuccessOutput)
    case failure(error: String)

    nonisolated public var promptRepresentation: Prompt {
        switch self {
        case .success(let base):
            base.promptRepresentation
        case .failure(let error):
            // The model reads {"error": "..."} as the tool's result. The reason is escaped
            // so quotes, backslashes or control characters in it cannot break the JSON.
            Prompt("{ \"error\" : \(Self.jsonStringLiteral(error)) }")
        }
    }

    /// Returns `text` as a double-quoted JSON string literal.
    ///
    /// Escapes `"`, `\` and all control characters below U+0020, as JSON strings require.
    nonisolated private static func jsonStringLiteral(_ text: String) -> String {
        var literal = "\""
        for scalar in text.unicodeScalars {
            switch scalar {
            case "\"":
                literal += "\\\""
            case "\\":
                literal += "\\\\"
            case "\n":
                literal += "\\n"
            case "\r":
                literal += "\\r"
            case "\t":
                literal += "\\t"
            case _ where scalar.value < 0x20:
                // Remaining control characters use the \uXXXX form.
                let hex = String(scalar.value, radix: 16, uppercase: true)
                literal += "\\u" + String(repeating: "0", count: 4 - hex.count) + hex
            default:
                literal.unicodeScalars.append(scalar)
            }
        }
        literal += "\""
        return literal
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import Foundation | |
| import FoundationModels | |
/// A tool call that is waiting for approval on a ``ToolApprovalGate``.
///
/// Approval handlers receive this value and pass `id` back to the gate to approve
/// or deny the call. Both properties must therefore be readable outside the module.
/// The type is `Sendable` because it is handed to an actor and to a notification task.
public struct PendingToolCall: Hashable, Sendable {
    /// Identifier used with the gate's `approve(callID:)` / `deny(callID:reason:)`.
    public let id: String
    /// Name of the tool that is asking to run.
    public let name: String

    /// Creates a pending call description.
    public init(id: String, name: String) {
        self.id = id
        self.name = name
    }
}
/// Coordinates approval decisions for gated tool calls.
///
/// Each call, keyed by its ``ToolCallID``, goes through these steps:
/// 1. ``register(call:)`` marks it pending and notifies `onApprovalRequired`.
/// 2. ``approve(callID:)``, ``deny(callID:reason:)`` or ``cancel(callID:)`` records a decision.
/// 3. ``waitForResolution(callID:)`` hands the decision to the tool and forgets the call.
///
/// Steps 2 and 3 may happen in either order. The first decision recorded for a
/// call is final: a later approve/deny/cancel before the decision is collected is
/// ignored. After collection, the ID is unknown to the gate.
public actor ToolApprovalGate {
    /// Errors thrown when a call ID is used out of sequence.
    enum ResolutionError: Error {
        case noPendingApproval(callID: ToolCallID)
        case duplicatePendingApproval(callID: ToolCallID)
        case duplicatePendingWait(callID: ToolCallID)
    }

    /// Called from a new task each time a call is registered and needs a decision.
    let onApprovalRequired: (PendingToolCall) async -> Void

    init(onApprovalRequired: @escaping (PendingToolCall) async -> Void) {
        self.onApprovalRequired = onApprovalRequired
    }

    typealias ToolCall = Transcript.ToolCall
    public typealias ToolCallID = String

    /// State of every call that has been registered but whose decision has not been collected.
    private var pending: [ToolCallID: PendingApproval] = [:]

    private enum PendingApproval {
        /// Registered; no decision and no waiter yet.
        case pending
        /// A waiter is suspended on this continuation; no decision yet.
        case waiting(CheckedContinuation<ApprovalDecision, Never>)
        /// A decision arrived before anyone waited for it.
        case resolved(ApprovalDecision)
    }

    /// The result of waiting on a ``ToolApprovalGate`` request.
    public enum ApprovalDecision: Sendable, Equatable {
        /// The pending tool call was approved and may execute.
        case approved
        /// The pending tool call was denied or cancelled.
        case denied(reason: String)

        /// Reason attached to decisions produced by ``ToolApprovalGate/cancel(callID:)``.
        static let cancelledReason = "cancelled"
    }

    /// Marks the tool call as pending before the approval-required event is emitted.
    ///
    /// `onApprovalRequired` is invoked from a new task after the call is recorded, so
    /// a handler that answers immediately always finds the call registered.
    /// - Throws: `ResolutionError.duplicatePendingApproval` if `call.id` is already tracked.
    public func register(call: PendingToolCall) throws {
        let callID = call.id
        guard pending[callID] == nil else {
            throw ResolutionError.duplicatePendingApproval(callID: callID)
        }
        pending[callID] = .pending
        Task {
            await onApprovalRequired(call)
        }
    }

    /// Suspends until the caller approves, denies, or cancels the pending tool call.
    ///
    /// Returns at once if a decision was already recorded. Either way, the call is
    /// forgotten once its decision is returned.
    /// - Throws: `ResolutionError.noPendingApproval` for an unknown ID,
    ///   `ResolutionError.duplicatePendingWait` if another waiter is already suspended.
    public func waitForResolution(callID: ToolCallID) async throws -> ApprovalDecision {
        guard let current = pending[callID] else {
            throw ResolutionError.noPendingApproval(callID: callID)
        }
        switch current {
        case .pending:
            return await withCheckedContinuation { continuation in
                pending[callID] = .waiting(continuation)
            }
        case .waiting:
            throw ResolutionError.duplicatePendingWait(callID: callID)
        case .resolved(let decision):
            pending.removeValue(forKey: callID)
            return decision
        }
    }

    /// Cancels a pending tool call if one exists, resolving it as denied with
    /// ``ApprovalDecision/cancelledReason``.
    ///
    /// Does nothing for unknown IDs or calls that already have a decision.
    public func cancel(callID: ToolCallID) {
        settle(callID: callID, with: .denied(reason: ApprovalDecision.cancelledReason))
    }

    /// Approves a pending tool call.
    /// - Throws: `ResolutionError.noPendingApproval` if `callID` is not tracked.
    public func approve(callID: ToolCallID) throws {
        try resolve(callID: callID, as: .approved)
    }

    /// Denies a pending tool call; `reason` is reported back to the model.
    /// - Throws: `ResolutionError.noPendingApproval` if `callID` is not tracked.
    public func deny(callID: ToolCallID, reason: String) throws {
        try resolve(callID: callID, as: .denied(reason: reason))
    }

    /// Records a caller-supplied decision, rejecting IDs the gate does not know about.
    private func resolve(callID: ToolCallID, as resolution: ApprovalDecision) throws {
        guard pending[callID] != nil else {
            throw ResolutionError.noPendingApproval(callID: callID)
        }
        settle(callID: callID, with: resolution)
    }

    /// Delivers `decision` to a suspended waiter, or stores it for a future one.
    ///
    /// If the call already has a stored decision, that first decision is kept. This
    /// prevents, for example, a late approval from replacing a cancellation.
    private func settle(callID: ToolCallID, with decision: ApprovalDecision) {
        switch pending[callID] {
        case .pending?:
            pending[callID] = .resolved(decision)
        case .waiting(let continuation)?:
            pending.removeValue(forKey: callID)
            continuation.resume(returning: decision)
        case .resolved?, nil:
            return
        }
    }
}
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note: there is absolutely no need for GatedTool<T: Tool> to keep the closures. They are a leftover from quickly adapting a type-erasing wrapper that wasn't generic. It should be updated to store the tool instance directly.