Skip to content

Instantly share code, notes, and snippets.

@airhorns
Created November 6, 2024 23:53
Show Gist options
  • Save airhorns/8de013ecede29bac3258cf1d72983fce to your computer and use it in GitHub Desktop.
Save airhorns/8de013ecede29bac3258cf1d72983fce to your computer and use it in GitHub Desktop.
/* eslint-disable no-console */
import type { AgentAction, AgentFinish } from "@langchain/core/agents";
import { BaseCallbackHandler } from "@langchain/core/callbacks/base";
import type { Document } from "@langchain/core/documents";
import type { Serialized } from "@langchain/core/load/serializable";
import {
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
type BaseMessageFields,
type MessageContent,
} from "@langchain/core/messages";
import type { Generation, LLMResult } from "@langchain/core/outputs";
import type { ChainValues } from "@langchain/core/utils/types";
import { startSpan, type Logger, type Span } from "braintrust";
/**
 * A chat message flattened to a plain object that carries an explicit role.
 *
 * One of the shapes returned by `BraintrustCallbackHandler.extractChatMessageContent`
 * and accepted as span input by `handleGenerationStart`.
 */
export type LlmMessage = {
// Speaker of the message, e.g. "user", "assistant" or "system".
role: string;
// Message body, typed exactly as LangChain's `BaseMessageFields["content"]`.
content: BaseMessageFields["content"];
// Optional extra payload, typed exactly as LangChain's `BaseMessageFields["additional_kwargs"]`.
additional_kwargs?: BaseMessageFields["additional_kwargs"];
};
/**
 * Same shape as `LlmMessage` but without a `role` field; used for messages
 * for which no role could be determined.
 */
export type AnonymousLlmMessage = {
// Message body, typed exactly as LangChain's `BaseMessageFields["content"]`.
content: BaseMessageFields["content"];
// Optional extra payload, typed exactly as LangChain's `BaseMessageFields["additional_kwargs"]`.
additional_kwargs?: BaseMessageFields["additional_kwargs"];
};
/**
 * A LangChain callback handler that logs to Braintrust.
 *
 * Every `handle*Start` callback opens a span and remembers it under the
 * LangChain `runId`; the matching `handle*End` / `handle*Error` callback logs
 * the result on that span, ends it and forgets it again. Nesting is expressed
 * by passing the id of the span registered for `parentRunId` as `parent`.
 *
 * NOTE(review): the `logger` passed to the constructor is stored but never read
 * by this class; all spans are created through the module-level `startSpan`.
 * Confirm whether root spans should be created from `logger` instead.
 *
 * NOTE(review): `parent` receives the parent span's `id`; confirm that this is
 * the identifier form `startSpan` expects.
 */
export class BraintrustCallbackHandler<IsAsyncFlush extends boolean = false> extends BaseCallbackHandler {
  name = "BraintrustCallbackHandler";
  debugEnabled: boolean = false;

  /** Invocation parameters that are copied onto LLM spans when present. */
  private static readonly modelParameterKeys = [
    "temperature",
    "max_tokens",
    "top_p",
    "frequency_penalty",
    "presence_penalty",
    "request_timeout",
  ] as const;

  /** Spans that have been started but not yet ended, keyed by LangChain run id. */
  private spansByRunId = new Map<string, Span>();

  constructor(readonly logger: Logger<IsAsyncFlush> | Span) {
    super();
  }

  /** Turns debug logging to the console on (default) or off. */
  debug(enabled: boolean = true): void {
    this.debugEnabled = enabled;
  }

  /** Writes `message` to the console, but only while debug logging is enabled. */
  _log(message: unknown): void {
    if (this.debugEnabled) {
      console.log("[braintrust-langchain]", message);
    }
  }

  /** Returns the id of the open span registered for `parentRunId`, if there is one. */
  private parentSpanId(parentRunId?: string): string | undefined {
    return parentRunId ? this.spansByRunId.get(parentRunId)?.id : undefined;
  }

  /**
   * Logs `event` on the span registered for `runId`, ends the span and drops
   * it from the registry.
   *
   * @param runId - LangChain run id the span was registered under.
   * @param handlerName - Name of the calling handler, used in the debug message
   *   emitted when no span is registered for `runId`.
   * @param event - Payload forwarded to `span.log`.
   */
  private finishSpan(runId: string, handlerName: string, event: Parameters<Span["log"]>[0]): void {
    const span = this.spansByRunId.get(runId);
    if (!span) {
      this._log(`${handlerName}: No span found for runId: ${runId}`);
      return;
    }
    // Unregister first so the entry cannot linger if logging or ending throws.
    this.spansByRunId.delete(runId);
    span.log(event);
    span.end();
  }

  /** Records a retriever failure on the run's span and closes it. */
  handleRetrieverError(err: unknown, runId: string, parentRunId?: string | undefined): void {
    this._log(`Retriever error: ${String(err)} with ID: ${runId}`);
    // String() instead of err.toString(): the latter throws for null/undefined.
    this.finishSpan(runId, "handleRetrieverError", { error: String(err) });
  }

  /**
   * Opens a span for a chain run. The span is named after `name`, else the
   * last segment of `chain.id`, else "Langchain Run".
   */
  async handleChainStart(
    chain: Serialized,
    inputs: ChainValues,
    runId: string,
    parentRunId?: string | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    runType?: string,
    name?: string
  ): Promise<void> {
    this._log(`Chain start with Id: ${runId}`);
    const runName = name ?? chain.id.at(-1)?.toString() ?? "Langchain Run";
    const span = startSpan({ name: runName, event: { input: inputs }, parent: this.parentSpanId(parentRunId) });
    this.spansByRunId.set(runId, span);
  }

  /**
   * Records an agent action as a span that is opened and closed immediately;
   * it is never registered under `runId`.
   */
  async handleAgentAction(action: AgentAction, runId?: string, parentRunId?: string): Promise<void> {
    this._log(`Agent action with ID: ${runId}`);
    const span = startSpan({ name: action.tool, event: { input: action }, parent: this.parentSpanId(parentRunId) });
    span.end();
  }

  /** Logs the agent's final result on the run's span and closes it. */
  async handleAgentEnd(action: AgentFinish, runId: string, parentRunId?: string): Promise<void> {
    this._log(`Agent finish with ID: ${runId}`);
    this.finishSpan(runId, "handleAgentEnd", { output: action });
  }

  /** Records a chain failure on the run's span and closes it. */
  async handleChainError(err: unknown, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Chain error: ${String(err)} with ID: ${runId}`);
    this.finishSpan(runId, "handleChainError", { error: String(err) });
  }

  /**
   * Opens an "llm" span for a model invocation. Shared by `handleLLMStart`
   * and `handleChatModelStart`.
   *
   * The span is named after `name`, else the last segment of `llm.id`. The
   * model name is read from `extraParams.invocation_params.model`, falling
   * back to `metadata.ls_model_name`. The sampling parameters listed in
   * `modelParameterKeys` are copied from the invocation parameters when they
   * are neither null nor undefined.
   */
  async handleGenerationStart(
    llm: Serialized,
    messages: (LlmMessage | MessageContent | AnonymousLlmMessage)[],
    runId: string,
    parentRunId?: string | undefined,
    extraParams?: Record<string, unknown> | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`Generation start with ID: ${runId}`);
    const runName = name ?? llm.id.at(-1)?.toString() ?? "langchain<unknown>";

    // `invocation_params` may be absent even when `extraParams` is present,
    // so normalise it to an object before reading from it.
    const rawInvocationParams = extraParams?.["invocation_params"];
    const invocationParams: Record<string, unknown> =
      typeof rawInvocationParams === "object" && rawInvocationParams !== null
        ? (rawInvocationParams as Record<string, unknown>)
        : {};

    const modelParameters: Record<string, unknown> = {};
    for (const key of BraintrustCallbackHandler.modelParameterKeys) {
      const value = invocationParams[key];
      if (value !== undefined && value !== null) {
        modelParameters[key] = value;
      }
    }

    const invocationModelName = invocationParams["model"];
    const metadataModelName = metadata?.["ls_model_name"];
    let extractedModelName: string | undefined;
    if (typeof invocationModelName === "string") {
      extractedModelName = invocationModelName;
    } else if (typeof metadataModelName === "string") {
      extractedModelName = metadataModelName;
    }

    const span = startSpan({
      name: runName,
      type: "llm",
      parent: this.parentSpanId(parentRunId),
      spanAttributes: {
        model: extractedModelName,
        modelParameters: modelParameters,
      },
    });
    span.log({ input: messages });
    this.spansByRunId.set(runId, span);
  }

  /** Flattens the batched chat messages into plain objects and opens an LLM span for them. */
  async handleChatModelStart(
    llm: Serialized,
    messages: BaseMessage[][],
    runId: string,
    parentRunId?: string | undefined,
    extraParams?: Record<string, unknown> | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`Chat model start with ID: ${runId}`);
    const prompts = messages.flatMap((batch) => batch.map((message) => this.extractChatMessageContent(message)));
    await this.handleGenerationStart(llm, prompts, runId, parentRunId, extraParams, tags, metadata, name);
  }

  /** Logs the chain outputs on the run's span and closes it. */
  async handleChainEnd(outputs: ChainValues, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Chain end with ID: ${runId}`);
    this.finishSpan(runId, "handleChainEnd", { output: outputs });
  }

  /** Opens an LLM span for a completion-style (string prompt) model call. */
  async handleLLMStart(
    llm: Serialized,
    prompts: string[],
    runId: string,
    parentRunId?: string | undefined,
    extraParams?: Record<string, unknown> | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`LLM start with ID: ${runId}`);
    await this.handleGenerationStart(llm, prompts, runId, parentRunId, extraParams, tags, metadata, name);
  }

  /** Opens a "tool" span with the tool input logged as `input`. */
  async handleToolStart(
    tool: Serialized,
    input: string,
    runId: string,
    parentRunId?: string | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`Tool start with ID: ${runId}`);
    const span = startSpan({
      name: name ?? tool.id.at(-1)?.toString(),
      type: "tool",
      parent: this.parentSpanId(parentRunId),
    });
    span.log({ input });
    this.spansByRunId.set(runId, span);
  }

  /** Opens a "function" span with the retriever query logged as `input`. */
  async handleRetrieverStart(
    retriever: Serialized,
    query: string,
    runId: string,
    parentRunId?: string | undefined,
    tags?: string[] | undefined,
    metadata?: Record<string, unknown> | undefined,
    name?: string
  ): Promise<void> {
    this._log(`Retriever start with ID: ${runId}`);
    const span = startSpan({
      name: name ?? retriever.id.at(-1)?.toString(),
      type: "function",
      parent: this.parentSpanId(parentRunId),
    });
    span.log({ input: query });
    this.spansByRunId.set(runId, span);
  }

  /** Logs the retrieved documents on the run's span and closes it. */
  async handleRetrieverEnd(documents: Document<Record<string, any>>[], runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Retriever end with ID: ${runId}`);
    this.finishSpan(runId, "handleRetrieverEnd", { output: documents });
  }

  /** Logs the tool output on the run's span and closes it. */
  async handleToolEnd(output: string, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Tool end with ID: ${runId}`);
    this.finishSpan(runId, "handleToolEnd", { output });
  }

  /** Records a tool failure on the run's span and closes it. */
  async handleToolError(err: unknown, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`Tool error ${String(err)} with ID: ${runId}`);
    this.finishSpan(runId, "handleToolError", { error: String(err) });
  }

  /**
   * Logs the last generation of the result, plus token usage, on the run's
   * span and closes it.
   *
   * Token usage is taken from `output.llmOutput.tokenUsage` when present,
   * otherwise from the last generation's message `usage_metadata`. A result
   * without any generation still closes the span, with `output` left undefined.
   */
  async handleLLMEnd(output: LLMResult, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`LLM end with ID: ${runId}`);
    if (!this.spansByRunId.has(runId)) {
      this._log(`handleLLMEnd: No span found for runId: ${runId}`);
      return;
    }

    // `generations` may be empty (or end in an empty batch); guard instead of indexing blindly.
    const lastResponse = output.generations.at(-1)?.at(-1);
    const llmUsage = output.llmOutput?.["tokenUsage"] ?? (lastResponse ? this.extractUsageMetadata(lastResponse) : undefined);

    let extractedOutput: unknown;
    if (lastResponse) {
      extractedOutput =
        "message" in lastResponse && lastResponse["message"] instanceof BaseMessage
          ? this.extractChatMessageContent(lastResponse["message"])
          : lastResponse.text;
    }

    // finishSpan also unregisters the span, so finished LLM runs do not accumulate.
    this.finishSpan(runId, "handleLLMEnd", { output: extractedOutput, metrics: llmUsage });
  }

  /** Not all models supports tokenUsage in llmOutput, can use AIMessage.usage_metadata instead */
  private extractUsageMetadata(
    generation: Generation
  ): { prompt_tokens: number; completion_tokens: number; total_tokens: number } | undefined {
    try {
      const usageMetadata =
        "message" in generation && (generation["message"] instanceof AIMessage || generation["message"] instanceof AIMessageChunk)
          ? generation["message"].usage_metadata
          : undefined;
      if (!usageMetadata) {
        return undefined;
      }
      return {
        prompt_tokens: usageMetadata.input_tokens,
        completion_tokens: usageMetadata.output_tokens,
        total_tokens: usageMetadata.total_tokens,
      };
    } catch (err) {
      this._log(`Error extracting usage metadata: ${String(err)}`);
      return undefined;
    }
  }

  /**
   * Converts a LangChain message into a plain `{ role?, content, additional_kwargs? }` object.
   *
   * Roles: human -> "user", AI (message or chunk) -> "assistant", system -> "system",
   * chat -> the message's `name`, else its `role`; function/tool and any other
   * message type -> the message's `name`. The `role` key is omitted when no
   * role could be determined. `additional_kwargs` is always included for
   * function/tool messages, and otherwise only when it carries a
   * `function_call` or `tool_calls`.
   */
  private extractChatMessageContent(message: BaseMessage): LlmMessage | AnonymousLlmMessage | MessageContent {
    let role: string | undefined;
    let alwaysIncludeKwargs = false;

    if (message instanceof HumanMessage) {
      role = "user";
    } else if (message instanceof ChatMessage) {
      // `name` is optional on chat messages; fall back to the message's own role.
      role = message.name ?? message.role;
    } else if (message instanceof AIMessage || message instanceof AIMessageChunk) {
      role = "assistant";
    } else if (message instanceof SystemMessage) {
      role = "system";
    } else if (message instanceof FunctionMessage || message instanceof ToolMessage) {
      role = message.name;
      alwaysIncludeKwargs = true;
    } else {
      // Unknown message type: only a non-empty name counts as a role.
      role = message.name || undefined;
    }

    const kwargs = message.additional_kwargs;
    const includeKwargs = alwaysIncludeKwargs || Boolean(kwargs?.function_call || kwargs?.tool_calls);
    const extracted: AnonymousLlmMessage = includeKwargs
      ? { content: message.content, additional_kwargs: kwargs }
      : { content: message.content };

    return role === undefined ? extracted : { ...extracted, role };
  }

  /** Records an LLM failure on the run's span and closes it. */
  async handleLLMError(err: unknown, runId: string, parentRunId?: string | undefined): Promise<void> {
    this._log(`LLM error ${String(err)} with ID: ${runId}`);
    this.finishSpan(runId, "handleLLMError", { error: String(err) });
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment