Langchain.js Naive ToT implementation
import { models } from "../models";
import { LLMChain, SerializedLLMChain } from "langchain/chains";
import { PromptTemplate } from "langchain/prompts";
import { AgentExecutor, BaseSingleActionAgent, StoppingMethod } from "langchain/agents";
import { CallbackManagerForChainRun, Callbacks } from "langchain/callbacks";
import { BaseMultiActionAgent } from "langchain/dist/agents/agent";
import { BaseMemory } from "langchain/memory";
import { AgentAction, AgentFinish, ChainValues } from "langchain/schema";
import { Tool } from "langchain/tools";
import { OutputParser } from "./output-parser";
import { FormatInstructionsOptions } from "langchain/schema/output_parser";
(async () => {
  const model = models["oa"];

  // Prompt that asks the model to fan out {sizeLimit} candidate answers for a given input.
  const thoughtPrompt = new PromptTemplate({
    template: "Given the current instruction: '{input}', generate {sizeLimit} different answers:",
    inputVariables: ["input", "sizeLimit"]
  });
  // Parses the model's numbered list of candidates into a plain string array.
  class ThoughtOutputParser extends OutputParser {
    promptTemplate: PromptTemplate;

    constructor(fields: { promptTemplate: PromptTemplate }) {
      super(fields.promptTemplate.template);
      this.promptTemplate = fields.promptTemplate;
    }

    //@ts-ignore
    async parse(text: string, callbacks?: Callbacks | undefined): Promise<any> {
      const thoughts = text
        .split("\n")
        .filter((item) => item !== "")
        .map((item) => item.replace(/^\d+\.\s*/, ""));
      console.log(thoughts);
      return thoughts;
    }

    //@ts-ignore
    getFormatInstructions(options?: FormatInstructionsOptions | undefined): string {
      throw new Error("Method not implemented.");
    }
  }
  // Placeholder; the evaluator chain below uses the base OutputParser directly.
  class EvaluatorOutputParser extends OutputParser {
  }
  // Generates candidate thoughts for a given input.
  const thoughtGenerator = new LLMChain({
    llm: model,
    prompt: thoughtPrompt,
    outputParser: new ThoughtOutputParser({ promptTemplate: thoughtPrompt })
  });

  // Scores a single thought against the original instruction as a float in [0, 1].
  const thoughtEvaluator = new LLMChain({
    llm: model,
    prompt: new PromptTemplate({
      template: "Given the following instruction as a context: '{context}' and the current state of reasoning: '{state_text}', critically evaluate its relevance and accuracy as a float between 0 and 1, and NOTHING ELSE:",
      inputVariables: ["state_text", "context"]
    }),
    outputParser: new OutputParser("")
  });
  interface ToT {
    thought: string;
    evaluation: number;
    children?: ToT[] | undefined;
  }

  interface ToTInput {
    sizeLimit: number; // k: number of candidate thoughts generated per step
    stepLimit: number; // T: maximum depth of the search tree
    threshold: number; // Vth: minimum evaluation score required to expand a thought
  }
  // Naive depth-first Tree-of-Thoughts executor: generates candidate thoughts,
  // scores them, and recursively expands the ones above the threshold.
  //@ts-ignore
  class ToTExecutor implements AgentExecutor {
    agent: BaseSingleActionAgent | BaseMultiActionAgent;
    tools: Tool[];
    totInput: ToTInput;
    returnIntermediateSteps: boolean;
    maxIterations?: number | undefined;
    earlyStoppingMethod: StoppingMethod;
    memory?: BaseMemory | undefined;
    evaluator: LLMChain<AgentAction | AgentFinish>;
    generator: LLMChain<AgentAction | AgentFinish>;

    constructor(fields: {
      totInput: ToTInput,
      evaluator: LLMChain<AgentAction | AgentFinish>,
      generator: LLMChain<AgentAction | AgentFinish>
    }) {
      this.totInput = fields.totInput;
      this.evaluator = fields.evaluator;
      this.generator = fields.generator;
      this.run = this.run.bind(this);
    }
    get inputKeys(): string[] {
      throw new Error("Method not implemented.");
    }
    get outputKeys(): string[] {
      throw new Error("Method not implemented.");
    }
    //@ts-ignore
    _call(inputs: ChainValues, runManager?: CallbackManagerForChainRun | undefined): Promise<ChainValues> {
      throw new Error("Method not implemented.");
    }
    //@ts-ignore
    _chainType(): "agent_executor" {
      throw new Error("Method not implemented.");
    }
    //@ts-ignore
    serialize(): SerializedLLMChain {
      throw new Error("Method not implemented.");
    }
    async run(input: any, _callbacks?: Callbacks | undefined): Promise<string> {
      let output: ToT[] = [];
      let context = input;
      let totInput = this.totInput;
      let { generator, evaluator } = this;

      // Depth-first search over the thought tree: at each step, generate
      // `sizeLimit` candidate thoughts, score each one, and recurse on those
      // whose score exceeds the threshold until `stepLimit` is reached.
      async function dfs(_currentStep: number, input: any) {
        if (_currentStep > totInput.stepLimit) {
          // Depth limit reached: generate a single final thought and record it.
          const thought = await generator.call({
            input,
            sizeLimit: 1
          });
          const evaluated = await evaluator.call({ state_text: thought.text[0], context });
          const evaluation = parseFloat(evaluated.text.log.trim());
          output.push({
            thought: thought.text[0],
            evaluation
          });
          return;
        }
        const thoughts = await generator.call({
          input,
          sizeLimit: totInput.sizeLimit
        });
        for (let i = 0; i < thoughts.text.length; i++) {
          const evaluated = await evaluator.call({ state_text: thoughts.text[i], context });
          const evaluation = parseFloat(evaluated.text.log.trim());
          console.log(`[step]:${_currentStep}`);
          console.log(`[thought]:${thoughts.text[i]}`);
          console.log(`[score]:${evaluation}`);
          // Only expand thoughts that score above the threshold.
          if (evaluation > totInput.threshold) {
            await dfs(_currentStep + 1, thoughts.text[i]);
          }
        }
      }

      await dfs(0, input);
      return Promise.resolve(JSON.stringify(output));
    }
    //@ts-ignore
    call(values: ChainValues, callbacks?: Callbacks | undefined): Promise<ChainValues> {
      throw new Error("Method not implemented.");
    }
    //@ts-ignore
    apply(inputs: ChainValues[], callbacks?: Callbacks[] | undefined): Promise<ChainValues> {
      throw new Error("Method not implemented.");
    }
    verbose: boolean;
    callbacks?: Callbacks | undefined;
  }
  const tot = new ToTExecutor({
    totInput: {
      sizeLimit: 3,   // k candidate thoughts per step
      stepLimit: 2,   // maximum tree depth T
      threshold: 0.5  // minimum score Vth to expand a thought
    },
    evaluator: thoughtEvaluator,
    generator: thoughtGenerator
  });

  const res = await tot.run("What steps do I need to take to become a good AI engineer?");

  // Print the collected leaf thoughts, sorted by evaluation score (ascending).
  console.log(
    JSON.parse(res).sort((a: ToT, b: ToT) => {
      return a.evaluation - b.evaluation;
    })
  );
})();
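
The gist imports two local modules that are not included in the file: ../models (the models map) and ./output-parser (the OutputParser base class). Below is a minimal sketch of what they might look like, assuming an OpenAI-backed model and a parser that returns an AgentFinish-shaped object (the executor reads .log off the evaluator's output). The names and defaults are illustrative, not the author's actual modules.

// models.ts (hypothetical sketch; assumes OPENAI_API_KEY is set in the environment)
import { OpenAI } from "langchain/llms/openai";

export const models = {
  oa: new OpenAI({ temperature: 0.7 })
};

// output-parser.ts (hypothetical sketch)
import { BaseOutputParser } from "langchain/schema/output_parser";
import { AgentFinish } from "langchain/schema";

export class OutputParser extends BaseOutputParser<AgentFinish> {
  constructor(public template: string) {
    super();
  }

  // Wrap the raw LLM text so callers can read `.log`, as the executor above does.
  async parse(text: string): Promise<AgentFinish> {
    return { returnValues: { output: text }, log: text };
  }

  getFormatInstructions(): string {
    return this.template;
  }
}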