Last active
February 4, 2024 16:45
-
-
Save koduki/d60410f8db7b94f92ab4c01187b61300 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
def _read_secret(filename):
    """Return the contents of ``~/.secret/<filename>`` without surrounding whitespace.

    Key files are commonly saved with a trailing newline. Without the strip,
    that newline would become part of the API key (and of the HTTP
    Authorization header built from it).
    """
    # expanduser("~") resolves USERPROFILE (or HOMEDRIVE+HOMEPATH) on Windows
    # and HOME elsewhere, unlike the drive-less, Windows-only HOMEPATH.
    secret_path = os.path.join(os.path.expanduser("~"), ".secret", filename)
    # utf-8-sig also drops a BOM that some Windows editors prepend.
    with open(secret_path, "r", encoding="utf-8-sig") as secret_file:
        return secret_file.read().strip()


# Keep these assignments before the `backend` imports that follow; presumably
# those modules read the keys at import time (TODO confirm). The ChatOpenAI
# clients later in this script are built without an explicit api_key, so they
# rely on OPENAI_API_KEY being set here.
os.environ["OPENAI_API_KEY"] = _read_secret("openai.txt")
os.environ["GOOGLE_API_KEY"] = _read_secret("gemini.txt")
from backend import weather_tool | |
from backend import short_talk_tool | |
from langchain.memory import ConversationBufferMemory | |
from langchain.schema.agent import AgentFinish | |
from langchain.tools.render import format_tool_to_openai_function | |
from langchain.agents.format_scratchpad import format_to_openai_functions | |
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser | |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_core.runnables import RunnableLambda, RunnablePassthrough | |
from langchain_openai import ChatOpenAI | |
from operator import itemgetter | |
# 出力フォーマットを定義 | |
from langchain_core.output_parsers import JsonOutputParser | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
# JSON reply schema for the chat model. parser.get_format_instructions() renders
# it into the {format_instructions} text of the chat prompt.
# Documented with comments on purpose: in pydantic v1 a class docstring becomes
# the schema "description" and would change the instructions sent to the LLM.
class Reply(BaseModel):
    # Required field (explicit `...`).
    # NOTE(review): "maxe" reads like a placeholder/typo, and it is sent verbatim
    # to the LLM as this field's description — confirm the intended wording.
    current_emotion: str = Field(..., description="maxe")
    # Required field: the character's (れん) reply to the user.
    character_reply: str = Field(..., description="れん's reply to User")
# In this script the parser is only used for its format instructions (the JSON
# schema of Reply); they are pre-filled into the chat prompt below.
parser = JsonOutputParser(pydantic_object=Reply)

# System prompt for the character, kept in a separate text file. A `with` block
# closes the file handle instead of leaking it.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
with open("C:\\Users\\koduki\\git\\ai-tuber\\src\\backend\\prompt_system.txt", "r", encoding='utf-8') as _prompt_file:
    prompt_system = _prompt_file.read()
# Prompt for the final reply. The messages follow the usual chat-agent order:
#   1. system/persona text,
#   2. prior turns from memory,
#   3. the new user message,
#   4. this turn's tool-call messages ("scratchpad").
# prompt_system presumably contains a {format_instructions} placeholder, which
# is pre-filled here with the JSON schema of Reply.
prompt_for_chat = ChatPromptTemplate.from_messages([
    ("system", prompt_system),
    # History must come before the new input. When it came after, the model saw
    # the previous turn's messages as the most recent ones and could answer
    # those instead of the current question.
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{input}"),
    MessagesPlaceholder(variable_name="scratchpad"),
]).partial(format_instructions=parser.get_format_instructions())
# Minimal prompt that only asks the function-calling model which tool (if any)
# to use for the user's input.
prompt_for_tools = ChatPromptTemplate.from_messages(
    [
        ("system", "You are agentai"),
        ("user", "{input}"),
    ]
)

# Tools the routing model may call; call_func looks them up by `.name`.
tools = [weather_tool.weather_api, short_talk_tool.talk]

# gpt-4 produces the final reply; gpt-3.5-turbo only decides on tool calls.
llm_for_chat = ChatOpenAI(temperature=0, model='gpt-4-0613')
_openai_functions = [format_tool_to_openai_function(tool) for tool in tools]
llm_with_tools = ChatOpenAI(temperature=0, model='gpt-3.5-turbo').bind(functions=_openai_functions)
def call_func(log):
    """Run the tool chosen by the routing model, if it chose one.

    Args:
        log: Output of ``OpenAIFunctionsAgentOutputParser``. This is either an
            ``AgentFinish`` (the model answered without calling a function) or
            an agent action whose ``.tool`` names an entry of ``tools`` and
            whose ``.tool_input`` is passed to that tool.

    Returns:
        A one-element list ``[(log, observation)]`` in the intermediate-steps
        shape consumed by ``format_to_openai_functions``. For ``AgentFinish``
        the observation is an empty list.

    Raises:
        ValueError: if the model requested a tool name not present in ``tools``.
    """
    if isinstance(log, AgentFinish):
        return [(log, [])]

    # First tool with a matching name, as before; None when there is no match.
    tool = next((candidate for candidate in tools if candidate.name == log.tool), None)
    if tool is None:
        # The model may emit a function name we never offered. Report it clearly
        # instead of letting next() raise a bare StopIteration.
        available = [candidate.name for candidate in tools]
        raise ValueError(f"Model requested unknown tool {log.tool!r}; available tools: {available}")

    observation = tool.run(log.tool_input)
    return [(log, observation)]
def store_memory(response):
    """Save this turn to the conversation memory and return the model output.

    ``response`` is the dict produced by the router. It carries the user's
    ``"input"`` and, under ``"return_values"``, the parsed chat-model result,
    whose ``.return_values['output']`` holds the final output.

    Returns:
        ``{"output": ...}`` — the same dict object that is saved to memory.
    """
    # Named so the builtin `input` is no longer shadowed.
    user_turn = {"input": response["input"]}
    reply = {"output": response["return_values"].return_values['output']}
    memory.save_context(user_turn, reply)
    return reply
# Conversation buffer, exposed as message objects under the "chat_history" key.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# Step 1: keep the original input, and add the prior turns plus the tool step.
# The tool step picks a tool, runs it, and renders the call/result as messages.
history_loader = RunnableLambda(memory.load_memory_variables) | itemgetter("chat_history")
tool_scratchpad = (
    prompt_for_tools
    | llm_with_tools
    | OpenAIFunctionsAgentOutputParser()
    | call_func
    | format_to_openai_functions
)
gather_context = RunnablePassthrough().assign(
    chat_history=history_loader,
    scratchpad=tool_scratchpad,
)

# Step 2: ask the chat model for the final reply, using that context.
generate_reply = RunnablePassthrough().assign(
    return_values=prompt_for_chat | llm_for_chat | OpenAIFunctionsAgentOutputParser(),
)

# Step 3: store the turn in memory and return {"output": ...}.
router = gather_context | generate_reply | store_memory

# Sample turn: "What's the weather in Tokyo today?". The result is neither
# stored nor printed.
router.invoke({"input": "今日の東京の天気は?"})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment