|
/// <reference types="@types/dom-chromium-ai" /> |
|
|
|
/** |
|
* Polyfill for the Prompt API tool calling functionality. |
|
* |
|
* This polyfill enables tool use with the Prompt API by intercepting the `tools` option |
|
* and implementing tool-calling logic manually. This is necessary because native browser |
|
* implementations don't yet support the `tools` parameter defined in the spec. |
|
* |
|
* Spec: https://webmachinelearning.github.io/prompt-api/ |
|
* |
|
* NOTE: This is a workaround implementation. Native implementations will handle tool calls |
|
* internally without needing JSON parsing or manual tool execution. |
|
*/ |
|
|
|
// A weak map to associate language model sessions with their registered tools. |
|
// Using a WeakMap ensures that once a session is garbage-collected, the associated tools are as well. |
|
const sessionTools = new WeakMap<LanguageModel, LanguageModelTool[]>() |
|
|
|
// Store the original static `create` method from the LanguageModel interface. |
|
const originalCreate = LanguageModel.create |
|
|
|
/**
 * A JSON structure that the model is instructed to use when it decides to call a tool.
 * This is a polyfill-specific format, not part of the official spec.
 */
interface ToolCallRequest {
  // One entry per tool invocation the model requests in a single turn.
  tool_calls: Array<{
    // Name of the tool to run; matched against a registered LanguageModelTool's `name`.
    name: string
    // Free-form arguments object, forwarded verbatim to the tool's execute().
    arguments: Record<string, any>
  }>
}
|
|
|
/** |
|
* Helper function to detect if a response contains a tool call request. |
|
* Handles both plain JSON and markdown-wrapped JSON. |
|
*/ |
|
function parseToolCallRequest(response: string): ToolCallRequest | null { |
|
let jsonToParse = response.trim() |
|
|
|
// Check if the response is wrapped in a markdown code block and extract the JSON. |
|
const markdownMatch = response.match(/```(?:json)?\s*\n([\s\S]*?)\n```/) |
|
if (markdownMatch && markdownMatch[1]) { |
|
jsonToParse = markdownMatch[1].trim() |
|
} |
|
|
|
try { |
|
const parsedResponse = JSON.parse(jsonToParse) |
|
if ( |
|
parsedResponse.tool_calls && |
|
Array.isArray(parsedResponse.tool_calls) && |
|
parsedResponse.tool_calls.length > 0 |
|
) { |
|
return parsedResponse as ToolCallRequest |
|
} |
|
} catch (e) { |
|
// Not valid JSON or not a tool call request |
|
} |
|
|
|
return null |
|
} |
|
|
|
/** |
|
* Helper function to build the system prompt that instructs the model how to use tools. |
|
*/ |
|
function buildToolSystemPrompt(tools: LanguageModelTool[]): string { |
|
const toolDescriptions = tools |
|
.map((tool) => { |
|
const schema = JSON.stringify(tool.inputSchema, null, 2) |
|
return `- ${tool.name}: ${tool.description}\n Input schema: ${schema}` |
|
}) |
|
.join('\n') |
|
|
|
return `You are a helpful assistant. You have access to the following tools. |
|
To use a tool, respond with a JSON object with a 'tool_calls' key, like this: |
|
{"tool_calls": [{"name": "tool_name", "arguments": {"arg1": "value1", "arg2": "value2"}}]} |
|
|
|
Available tools: |
|
${toolDescriptions} |
|
|
|
When you need to use a tool, respond ONLY with the JSON tool call request. After receiving the tool results, provide your final answer to the user.` |
|
} |
|
|
|
/** |
|
* Helper function to prepend system prompt to the first user message. |
|
* This is necessary because the prompt() and promptStreaming() methods don't support system roles. |
|
*/ |
|
function prependSystemPromptToMessages( |
|
messages: LanguageModelMessage[], |
|
systemPrompt: string |
|
): LanguageModelMessage[] { |
|
const allPrompts = Array.isArray(messages) ? [...messages] : [messages] |
|
|
|
const firstUserMessageIndex = allPrompts.findIndex((p) => p.role === 'user') |
|
|
|
if (firstUserMessageIndex !== -1) { |
|
const firstUserMessage = allPrompts[firstUserMessageIndex] |
|
const originalContent = |
|
typeof firstUserMessage.content === 'string' |
|
? firstUserMessage.content |
|
: JSON.stringify(firstUserMessage.content) |
|
|
|
allPrompts[firstUserMessageIndex] = { |
|
...firstUserMessage, |
|
content: `${systemPrompt}\n\n${originalContent}`, |
|
} |
|
} else { |
|
// If there's no user message, create one with the system prompt. |
|
allPrompts.unshift({ |
|
role: 'user', |
|
content: systemPrompt, |
|
}) |
|
} |
|
|
|
return allPrompts |
|
} |
|
|
|
/** |
|
* Execute tool calls and return formatted results. |
|
*/ |
|
async function executeToolCalls( |
|
toolCallRequest: ToolCallRequest, |
|
tools: LanguageModelTool[] |
|
): Promise<string> { |
|
const toolResults = [] |
|
|
|
for (const call of toolCallRequest.tool_calls) { |
|
const tool = tools.find((t) => t.name === call.name) |
|
|
|
if (!tool) { |
|
toolResults.push({ |
|
tool_call_name: call.name, |
|
result: `Error: Tool '${call.name}' not found.`, |
|
}) |
|
continue |
|
} |
|
|
|
try { |
|
// Execute the tool with the provided arguments |
|
const result = await tool.execute(call.arguments) |
|
toolResults.push({ |
|
tool_call_name: call.name, |
|
result, |
|
}) |
|
} catch (error) { |
|
toolResults.push({ |
|
tool_call_name: call.name, |
|
result: `Error executing tool '${call.name}': ${error instanceof Error ? error.message : String(error)}`, |
|
}) |
|
} |
|
} |
|
|
|
return `Tool results: ${JSON.stringify(toolResults)}` |
|
} |
|
|
|
/**
 * Wraps LanguageModel.create to intercept the `tools` parameter.
 * This allows us to store the tools and later apply the tool-calling logic
 * to the session's prompt methods.
 */
LanguageModel.create = async function (
  options?: LanguageModelCreateOptions
): Promise<LanguageModel> {
  const tools = options?.tools ?? []

  // If no tools are provided, just call the original create method.
  if (tools.length === 0) {
    return originalCreate(options)
  }

  // Create a copy of options and remove tools, as the native implementation
  // does not yet support their execution.
  const newOptions: LanguageModelCreateOptions = { ...options }
  delete newOptions.tools

  // Call the original create method.
  const session = await originalCreate(newOptions)

  // Store the tools associated with the newly created session.
  sessionTools.set(session, tools)

  // Instruction text describing the JSON tool-call protocol; it is folded
  // into the first user message of every wrapped call below.
  const systemPrompt = buildToolSystemPrompt(tools)

  // Wrap the prompt method to handle the tool-calling logic.
  // The wrapper keeps the original (input, options) signature and string
  // return type, so callers are unaffected.
  const originalPrompt = session.prompt.bind(session)
  session.prompt = async (
    input: LanguageModelPrompt,
    promptOptions?: LanguageModelPromptOptions
  ): Promise<string> => {
    // Normalize input to array of messages
    let allPrompts: LanguageModelMessage[] = Array.isArray(input)
      ? [...input]
      : [{ role: 'user', content: input }]

    // Prepend system prompt with tool instructions
    allPrompts = prependSystemPromptToMessages(allPrompts, systemPrompt)

    // Tool-calling loop: continue until we get a non-tool-call response
    let maxIterations = 10 // Prevent infinite loops
    let iteration = 0

    // NOTE(review): promptOptions are forwarded unchanged; a JSON
    // responseConstraint could interact with the tool-call JSON protocol —
    // confirm behavior if both are used together.
    while (iteration < maxIterations) {
      iteration++

      const response = await originalPrompt(allPrompts, promptOptions)
      const toolCallRequest = parseToolCallRequest(response)

      if (!toolCallRequest) {
        // No tool call detected, return the final response
        return response
      }

      // Tool call detected: add assistant message and execute tools
      allPrompts.push({ role: 'assistant', content: response })

      const toolResultsMessage = await executeToolCalls(toolCallRequest, tools)

      // Add tool results to conversation history.
      // NOTE(review): tool-call turns are appended only to this local
      // message list, not persisted into the session's own history.
      allPrompts.push({
        role: 'user',
        content: toolResultsMessage,
      })

      // Continue the loop to get the next response
    }

    throw new Error(
      `Tool calling loop exceeded maximum iterations (${maxIterations}). This may indicate a problem with the model or tool configuration.`
    )
  }

  // Wrap the promptStreaming method to handle the tool-calling logic.
  const originalPromptStreaming = session.promptStreaming.bind(session)
  session.promptStreaming = function (
    input: LanguageModelPrompt,
    promptOptions?: LanguageModelPromptOptions
  ): ReadableStream<string> {
    // Normalize input to array of messages
    let allPrompts: LanguageModelMessage[] = Array.isArray(input)
      ? [...input]
      : [{ role: 'user', content: input }]

    // Prepend system prompt with tool instructions
    allPrompts = prependSystemPromptToMessages(allPrompts, systemPrompt)

    // Create a ReadableStream that handles the tool-calling loop.
    // All work happens in start(): intermediate tool-call rounds are fully
    // buffered and only the final (non-tool-call) response is enqueued.
    return new ReadableStream<string>({
      async start(controller) {
        try {
          let maxIterations = 10 // Prevent infinite loops
          let iteration = 0

          // Tool-calling loop
          while (iteration < maxIterations) {
            iteration++

            let response = ''
            const stream = originalPromptStreaming(allPrompts, promptOptions)
            const reader = stream.getReader()

            // Collect the full response from the stream
            // Note: We need to buffer the full response to detect tool calls
            // This is a limitation of the polyfill approach
            try {
              while (true) {
                const { done, value } = await reader.read()
                if (done) break
                response += value
              }
            } finally {
              // Always release the reader, even if read() rejects.
              reader.releaseLock()
            }

            const toolCallRequest = parseToolCallRequest(response)

            if (!toolCallRequest) {
              // No tool call detected, stream the final response
              // For streaming efficiency, we could chunk this, but for simplicity
              // we'll send it as a single chunk
              controller.enqueue(response)
              controller.close()
              return
            }

            // Tool call detected: add assistant message and execute tools
            allPrompts.push({ role: 'assistant', content: response })

            const toolResultsMessage = await executeToolCalls(
              toolCallRequest,
              tools
            )

            // Add tool results to conversation history
            allPrompts.push({
              role: 'user',
              content: toolResultsMessage,
            })

            // Continue the loop to get the next response
          }

          // If we've exceeded max iterations, surface the error through the
          // stream (consumers see a rejected read) rather than throwing.
          controller.error(
            new Error(
              `Tool calling loop exceeded maximum iterations (${maxIterations}). This may indicate a problem with the model or tool configuration.`
            )
          )
        } catch (error) {
          // Propagate any unexpected failure to the stream consumer.
          controller.error(error)
        }
      },
    })
  }

  return session
}