Last active
September 6, 2024 11:54
-
-
Save bamboriz/2fe2624e901d579ecec40af7187ab7e8 to your computer and use it in GitHub Desktop.
Say goodbye to boring Chatbots by combining Structure (Chatbot Frameworks) & Flexibility (LLMs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import operator | |
from pprint import pprint | |
from typing import Annotated, Optional, Sequence, TypedDict, List | |
from langchain.tools import tool | |
from langchain.tools.render import format_tool_to_openai_function | |
from langchain_core.load import dumps, loads | |
from langchain_core.messages import ( | |
AIMessage, | |
BaseMessage, | |
FunctionMessage, | |
HumanMessage, | |
SystemMessage, | |
) | |
from langchain_openai import ChatOpenAI | |
from langgraph.graph import END, StateGraph | |
from langgraph.prebuilt import ToolExecutor, ToolInvocation | |
from functions.cache.cache_utils import get_llm_memory, store_llm_memory | |
from functions.lexbot import LexBot | |
# Type definitions | |
class AgentState(TypedDict):
    """State dictionary passed between the graph nodes defined in this module."""

    # Conversation history. The `operator.add` metadata is presumably the
    # reducer LangGraph uses to merge a node's returned "messages" into the
    # existing list (concatenation instead of overwrite) -- confirm against
    # the installed LangGraph version.
    messages: Annotated[Sequence[BaseMessage], operator.add]
# Constants | |
# System prompt text. llm_bot() wraps it in a SystemMessage and places it at
# the start of the history only when no stored memory exists for the user.
DECISION_AND_RESPONSE_PROMPT = """
You are a smart orchestrator facilitating communication between a user and a Lex bot. Do not make any reference of the bot to the user.
Follow these guidelines:
Clarifying Questions:
If the Lex bot asks the user a question, and the user responds with a clarifying question, address the user's query directly.
DO NOT RESPOND TO ANY QUESTIONS FROM THE USER THAT ARE NOT RELATED TO THE CONVERSATION. Politely tell them you have no idea.
"""
# Global variables | |
# Lex bot instance used by the `lex_bot` tool. It is None at import time and
# is reassigned by llm_bot() on every call, so it is shared module-wide state.
bot: Optional[LexBot] = None
# Base chat model (temperature 0, streaming enabled); the function-calling
# variant bound to the tools is derived from it further down as MODEL.
model = ChatOpenAI(temperature=0, streaming=True, model_name="gpt-4o")
def filter_messages(
    messages: List[BaseMessage], max_recent: int = 10
) -> List[BaseMessage]:
    """Trim a conversation history to its first message plus a recent window.

    The first message is always kept (llm_bot places the system prompt there
    when a conversation starts) together with the ``max_recent`` newest
    messages.

    Args:
        messages: Full conversation history, oldest first.
        max_recent: Number of trailing messages to keep. Defaults to 10,
            which was previously hard-coded.

    Returns:
        ``messages`` itself when it holds at most ``max_recent`` items,
        otherwise a new list made of the first message followed by the last
        ``max_recent`` messages.

    Raises:
        ValueError: If ``max_recent`` is negative.
    """
    if max_recent < 0:
        raise ValueError("max_recent must be non-negative")
    if len(messages) <= max_recent:
        return messages
    # `messages[-0:]` would return the whole list, so a zero-sized window
    # has to be handled explicitly.
    recent = messages[-max_recent:] if max_recent else []
    return messages[:1] + recent
# Tool definitions | |
# The docstring of this function is kept verbatim on purpose: LangChain's
# `tool` decorator presumably uses it as the tool description handed to the
# model, so editing it would change runtime behaviour -- confirm before rewording.
# NOTE(review): call_tool() iterates the result looking for a "messages" key,
# which suggests the bot returns a mapping rather than a plain str -- verify.
@tool("lex_bot", return_direct=True)
def lex_bot(text: str, session_id: Optional[str] = None) -> str:
    """Returns the bot prompt."""
    # Relies on the module-level `bot`, which llm_bot() assigns before the
    # graph is invoked.
    return bot.get_bot_response(text, session_id)
# Setup | |
# Tools the agent may call.
tools = [lex_bot]
# OpenAI function-calling schemas, one per tool.
functions = list(map(format_tool_to_openai_function, tools))
# Chat model with the function schemas bound; used by call_model().
MODEL = model.bind_functions(functions)
# Executor that runs a ToolInvocation against `tools`; used by call_tool().
TOOL_EXECUTOR = ToolExecutor(tools)
# Helper functions | |
def should_continue(state: AgentState) -> str:
    """Pick the branch to follow after the agent node.

    Args:
        state: Graph state; the newest entry of ``state["messages"]`` is
            inspected.

    Returns:
        ``"continue"`` when the newest message's ``additional_kwargs``
        contains a ``"function_call"`` entry, ``"end"`` otherwise.
    """
    newest = state["messages"][-1]
    if "function_call" in newest.additional_kwargs:
        return "continue"
    return "end"
def call_model(state: AgentState) -> AgentState:
    """Agent node: answer directly or request a call to the Lex bot.

    The history is trimmed with ``filter_messages`` and the newest message is
    inspected. If its text contains the trigger phrase "hello code"
    (case-insensitive), the model is bypassed and a synthetic ``AIMessage``
    carrying a ``lex_bot`` function call is returned. Otherwise the trimmed
    history is sent to ``MODEL``.

    Args:
        state: Graph state; ``state["messages"]`` must be non-empty.

    Returns:
        A partial state update holding exactly one new message under
        ``"messages"``.
    """
    filtered_messages = filter_messages(state["messages"])
    human_input = filtered_messages[-1].content

    if "hello code" in human_input.lower():
        # Force bot call for first message.
        # The arguments are serialised with json.dumps rather than string
        # interpolation: user text containing quotes, backslashes or newlines
        # would otherwise yield invalid JSON and break json.loads in
        # call_tool. The compact separators keep the output identical to the
        # previous format for plain text.
        arguments = json.dumps(
            {"text": human_input},
            ensure_ascii=False,
            separators=(",", ":"),
        )
        forced_call = AIMessage(
            content="",
            additional_kwargs={
                "function_call": {
                    "arguments": arguments,
                    "name": "lex_bot",
                }
            },
        )
        return {"messages": [forced_call]}

    response = MODEL.invoke(filtered_messages)
    return {"messages": [response]}
def call_tool(state: AgentState, config: dict) -> AgentState:
    """Action node: run the function call requested by the newest message.

    Args:
        state: Graph state whose newest message carries a ``"function_call"``
            entry in ``additional_kwargs``.
        config: Run configuration; ``config["configurable"]["session_id"]``
            is injected into the tool arguments.

    Returns:
        A partial state update with a single ``FunctionMessage`` whose
        content is the JSON-encoded ``"messages"`` portion of the tool result.
    """
    requesting_message = state["messages"][-1]
    session = config["configurable"]["session_id"]

    call_spec = requesting_message.additional_kwargs["function_call"]
    tool_args = json.loads(call_spec["arguments"])
    # The session id always comes from the run config, replacing any value
    # already present in the decoded arguments.
    tool_args.update(session_id=session)
    print(f"Calling Bot with {tool_args}")

    invocation = ToolInvocation(tool=call_spec["name"], tool_input=tool_args)
    print(f"The agent action is {invocation}")

    tool_result = TOOL_EXECUTOR.invoke(invocation)

    # Keep only the "messages" entry. The result is iterated (not probed with
    # `in`) so this behaves the same whatever iterable the tool returns.
    message_response = {}
    for key in tool_result:
        if key == "messages":
            message_response[key] = tool_result[key]
    print(f"The tool result is: {message_response}")

    reply = FunctionMessage(
        content=json.dumps(message_response),
        name=invocation.tool,
    )
    return {"messages": [reply]}
def llm_bot(llm_bot_input: str, user_id: str, lex_bot: LexBot) -> dict:
    """Run one conversational turn through the agent graph.

    Args:
        llm_bot_input: The user's message text for this turn.
        user_id: Key for loading/storing conversation memory; also passed to
            the graph as the ``session_id``.
        lex_bot: Bot instance to use for this turn. It is stored in the
            module-level ``bot`` so the ``lex_bot`` tool can reach it (inside
            this function the parameter shadows that tool's name).

    Returns:
        The decoded JSON content of the final message when it is a function
        message, otherwise ``{"messages": [{"content": <final text>}]}``.
    """
    global bot
    bot = lex_bot

    # Graph: agent -> (action | END), action -> END.
    graph = StateGraph(AgentState)
    graph.add_node("agent", call_model)
    graph.add_node("action", call_tool)
    graph.set_entry_point("agent")
    graph.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "action", "end": END},
    )
    graph.add_edge("action", END)
    app = graph.compile()

    system_message = SystemMessage(content=DECISION_AND_RESPONSE_PROMPT)
    user_input = HumanMessage(content=llm_bot_input)

    # Continue from stored history when there is one; otherwise start a new
    # conversation headed by the system prompt.
    stored_history = get_llm_memory(user_id)
    if stored_history:
        conversation = loads(stored_history) + [user_input]
    else:
        conversation = [system_message, user_input]

    run_config = {"configurable": {"session_id": user_id}}
    output = app.invoke({"messages": conversation}, run_config)

    print("STATE:")
    pprint(output)
    print("END STATE")

    # Persist the full resulting history for the next turn.
    store_llm_memory(dumps(output.get("messages")), user_id)

    final_message = output.get("messages")[-1]
    if final_message.type == "function":
        return json.loads(final_message.content)
    return {"messages": [{"content": final_message.content}]}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment