/* eslint-disable no-console */
import type { AgentAction, AgentFinish } from "@langchain/core/agents";
import { BaseCallbackHandler } from "@langchain/core/callbacks/base";
import type { Document } from "@langchain/core/documents";
import type { Serialized } from "@langchain/core/load/serializable";
import {
  AIMessage,
  AIMessageChunk,
  BaseMessage,
  ChatMessage,
  FunctionMessage,
  HumanMessage,
  SystemMessage,
  ToolMessage,
  type BaseMessageFields,
  type MessageContent,
} from "@langchain/core/messages";
import type { Generation, LLMResult } from "@langchain/core/outputs";
import type { ChainValues } from "@langchain/core/utils/types";
import { startSpan, type Logger, type Span } from "braintrust";

export type LlmMessage = {
  role: string;
  content: BaseMessageFields["content"];
  additional_kwargs?: BaseMessageFields["additional_kwargs"];
};

export type AnonymousLlmMessage = {
  content: BaseMessageFields["content"];
  additional_kwargs?: BaseMessageFields["additional_kwargs"];
};

/**
 * A callback handler that logs to Braintrust.
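 *
 * Pass an instance via `callbacks` on LangChain calls to record chain, LLM,
 * tool, and retriever runs as Braintrust spans.
 *
 * @example
 * // A minimal usage sketch (not part of the original gist). It assumes
 * // `@langchain/openai` is installed and BRAINTRUST_API_KEY / OPENAI_API_KEY
 * // are set in the environment; the project and model names are placeholders.
 * import { initLogger } from "braintrust";
 * import { ChatOpenAI } from "@langchain/openai";
 *
 * const handler = new BraintrustCallbackHandler(initLogger({ projectName: "my-project" }));
 * const model = new ChatOpenAI({ model: "gpt-4o-mini" });
 * await model.invoke("Say hello!", { callbacks: [handler] });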
 */
export class BraintrustCallbackHandler<IsAsyncFlush extends boolean = false> extends BaseCallbackHandler {
  name = "BraintrustCallbackHandler";
  debugEnabled: boolean = false;
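
  /** Maps LangChain run IDs to their in-flight Braintrust spans. */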
  private spansByRunId: Record<string, Span> = {};

  constructor(readonly logger: Logger<IsAsyncFlush> | Span) {
    super();
  }

  debug(enabled: boolean = true): void {
    this.debugEnabled = enabled;
  }

  _log(message: any): void {
    if (this.debugEnabled) {
      console.log("[braintrust-langchain]", message);
    }
  }

  handleRetrieverError(err: any, runId: string, parentRunId?: string | undefined) {
    this._log(`Retriever error: ${err} with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (!span) {
      this._log(`handleRetrieverError: No span found for runId: ${runId}`);
      return;
    }
    span.log({ error: err.toString() });
    span.end();
    delete this.spansByRunId[runId];
  }

  async handleChainStart(
    chain: Serialized,
    inputs: ChainValues,
    runId: string,
    parentRunId?: string | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    runType?: string,
    name?: string
  ): Promise<void> {
    this._log(`Chain start with ID: ${runId}`);
    const runName = name ?? chain.id.at(-1)?.toString() ?? "Langchain Run";
    const parent = parentRunId ? this.spansByRunId[parentRunId] : undefined;
    const span = startSpan({ name: runName, event: { input: inputs }, parent: parent?.id });
    this.spansByRunId[runId] = span;
  }

  async handleAgentAction(action: AgentAction, runId?: string, parentRunId?: string): Promise<void> {
    this._log(`Agent action with ID: ${runId}`);
    const parent = parentRunId ? this.spansByRunId[parentRunId] : undefined;
    const span = startSpan({ name: action.tool, event: { input: action }, parent: parent?.id });
    span.end();
  }

  async handleAgentEnd(action: AgentFinish, runId: string, parentRunId?: string): Promise<void> {
    this._log(`Agent finish with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (!span) {
      this._log(`handleAgentEnd: No span found for runId: ${runId}`);
      return;
    }
    span.log({ output: action });
    span.end();
    delete this.spansByRunId[runId];
  }

  async handleChainError(err: any, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Chain error: ${err} with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (!span) {
      this._log(`handleChainError: No span found for runId: ${runId}`);
      return;
    }
    span.log({ error: err.toString() });
    span.end();
    delete this.spansByRunId[runId];
  }
  async handleGenerationStart(
    llm: Serialized,
    messages: (LlmMessage | MessageContent | AnonymousLlmMessage)[],
    runId: string,
    parentRunId?: string | undefined,
    extraParams?: Record<string, unknown> | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`Generation start with ID: ${runId}`);
    const runName = name ?? llm.id.at(-1)?.toString() ?? "langchain<unknown>";

    // Collect common sampling parameters from the invocation params, skipping any that are unset.
    const modelParameters: Record<string, any> = {};
    const invocationParams = extraParams?.["invocation_params"];
    for (const [key, value] of Object.entries({
      temperature: (invocationParams as any)?.temperature,
      max_tokens: (invocationParams as any)?.max_tokens,
      top_p: (invocationParams as any)?.top_p,
      frequency_penalty: (invocationParams as any)?.frequency_penalty,
      presence_penalty: (invocationParams as any)?.presence_penalty,
      request_timeout: (invocationParams as any)?.request_timeout,
    })) {
      if (value !== undefined && value !== null) {
        modelParameters[key] = value;
      }
    }

    interface InvocationParams {
      _type?: string;
      model?: string;
      model_name?: string;
      repo_id?: string;
    }

    // Prefer the model name from the invocation params, falling back to LangSmith metadata.
    let extractedModelName: string | undefined;
    if (extraParams) {
      const invocationParamsModelName = (extraParams.invocation_params as InvocationParams | undefined)?.model;
      const metadataModelName = metadata && "ls_model_name" in metadata ? (metadata["ls_model_name"] as string) : undefined;
      extractedModelName = invocationParamsModelName ?? metadataModelName;
    }

    const parent = parentRunId ? this.spansByRunId[parentRunId] : undefined;
    const span = startSpan({
      name: runName,
      type: "llm",
      parent: parent?.id,
      spanAttributes: {
        model: extractedModelName,
        modelParameters,
      },
    });
    span.log({ input: messages });
    this.spansByRunId[runId] = span;
  }

  async handleChatModelStart(
    llm: Serialized,
    messages: BaseMessage[][],
    runId: string,
    parentRunId?: string | undefined,
    extraParams?: Record<string, unknown> | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`Chat model start with ID: ${runId}`);
    const prompts = messages.flatMap((message) => message.map((m) => this.extractChatMessageContent(m)));
    return await this.handleGenerationStart(llm, prompts, runId, parentRunId, extraParams, tags, metadata, name);
  }

  async handleChainEnd(outputs: ChainValues, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Chain end with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (!span) {
      this._log(`handleChainEnd: No span found for runId: ${runId}`);
      return;
    }
    span.log({ output: outputs });
    span.end();
    delete this.spansByRunId[runId];
  }

  async handleLLMStart(
    llm: Serialized,
    prompts: string[],
    runId: string,
    parentRunId?: string | undefined,
    extraParams?: Record<string, unknown> | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`LLM start with ID: ${runId}`);
    return await this.handleGenerationStart(llm, prompts, runId, parentRunId, extraParams, tags, metadata, name);
  }

  async handleToolStart(
    tool: Serialized,
    input: string,
    runId: string,
    parentRunId?: string | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`Tool start with ID: ${runId}`);
    const parent = parentRunId ? this.spansByRunId[parentRunId] : undefined;
    const span = startSpan({ name: name ?? tool.id.at(-1)?.toString(), type: "tool", parent: parent?.id });
    span.log({ input });
    this.spansByRunId[runId] = span;
  }

  async handleRetrieverStart(
    retriever: Serialized,
    query: string,
    runId: string,
    parentRunId?: string | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`Retriever start with ID: ${runId}`);
    const parent = parentRunId ? this.spansByRunId[parentRunId] : undefined;
    const span = startSpan({ name: name ?? retriever.id.at(-1)?.toString(), type: "function", parent: parent?.id });
    span.log({ input: query });
    this.spansByRunId[runId] = span;
  }

  async handleRetrieverEnd(documents: Document<Record<string, any>>[], runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Retriever end with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (!span) {
      this._log(`handleRetrieverEnd: No span found for runId: ${runId}`);
      return;
    }
    span.log({ output: documents });
    span.end();
    delete this.spansByRunId[runId];
  }

  async handleToolEnd(output: string, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Tool end with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (!span) {
      this._log(`handleToolEnd: No span found for runId: ${runId}`);
      return;
    }
    span.log({ output });
    span.end();
    delete this.spansByRunId[runId];
  }

  async handleToolError(err: any, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Tool error ${err} with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (!span) {
      this._log(`handleToolError: No span found for runId: ${runId}`);
      return;
    }
    span.log({ error: err.toString() });
    span.end();
    delete this.spansByRunId[runId];
  }

  async handleLLMEnd(output: LLMResult, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`LLM end with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (!span) {
      this._log(`handleLLMEnd: No span found for runId: ${runId}`);
      return;
    }
    // Log the last generation of the last batch, preferring chat message content when present.
    const lastGenerations = output.generations[output.generations.length - 1];
    const lastResponse = lastGenerations[lastGenerations.length - 1];
    const llmUsage = output.llmOutput?.["tokenUsage"] ?? this.extractUsageMetadata(lastResponse);
    const extractedOutput =
      "message" in lastResponse && lastResponse["message"] instanceof BaseMessage
        ? this.extractChatMessageContent(lastResponse["message"])
        : lastResponse.text;
    span.log({ output: extractedOutput, metrics: llmUsage });
    span.end();
    delete this.spansByRunId[runId];
  }

  /** Not all models support tokenUsage in llmOutput; fall back to AIMessage.usage_metadata instead. */
  private extractUsageMetadata(
    generation: Generation
  ): { prompt_tokens: number; completion_tokens: number; total_tokens: number } | undefined {
    try {
      const usageMetadata =
        "message" in generation && (generation["message"] instanceof AIMessage || generation["message"] instanceof AIMessageChunk)
          ? generation["message"].usage_metadata
          : undefined;
      if (!usageMetadata) {
        return;
      }
      return {
        prompt_tokens: usageMetadata.input_tokens,
        completion_tokens: usageMetadata.output_tokens,
        total_tokens: usageMetadata.total_tokens,
      };
    } catch (err) {
      this._log(`Error extracting usage metadata: ${err}`);
      return;
    }
  }
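
  /** Normalizes LangChain message classes into plain role/content objects for span logging. */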
  private extractChatMessageContent(message: BaseMessage): LlmMessage | AnonymousLlmMessage | MessageContent {
    let response = undefined;
    if (message instanceof HumanMessage) {
      response = { content: message.content, role: "user" };
    } else if (message instanceof ChatMessage) {
      response = { content: message.content, role: message.name };
    } else if (message instanceof AIMessage) {
      response = { content: message.content, role: "assistant" };
    } else if (message instanceof SystemMessage) {
      response = { content: message.content, role: "system" };
    } else if (message instanceof FunctionMessage) {
      response = { content: message.content, additional_kwargs: message.additional_kwargs, role: message.name };
    } else if (message instanceof ToolMessage) {
      response = { content: message.content, additional_kwargs: message.additional_kwargs, role: message.name };
    } else if (!message.name) {
      response = { content: message.content };
    } else {
      response = {
        role: message.name,
        content: message.content,
      };
    }
    if (message.additional_kwargs.function_call || message.additional_kwargs.tool_calls) {
      return { ...response, additional_kwargs: message.additional_kwargs };
    }
    return response;
  }

  async handleLLMError(err: any, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`LLM error ${err} with ID: ${runId}`);
    const span = this.spansByRunId[runId];
    if (span) {
      span.log({ error: err.toString() });
      span.end();
      delete this.spansByRunId[runId];
    }
  }
}