Created
June 18, 2024 07:53
-
-
Save tubone24/4ef910b152bc1017da36e7d7354e906d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, Dict, List | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain.memory.chat_memory import BaseChatMemory | |
| from langchain_core.messages.ai import AIMessage | |
| from langchain_core.messages.base import BaseMessage | |
| from langchain_core.messages.chat import ChatMessage | |
| from langchain_core.messages.function import FunctionMessage | |
| from langchain_core.messages.human import HumanMessage | |
| from langchain_core.messages.system import SystemMessage | |
| from langchain_core.messages.tool import ToolMessage | |
| class OptimizedClaude3ConversationalBufferMemory(BaseChatMemory): | |
| """ | |
| Claude 3 はXML形式のデータをうまく扱える傾向があるため、ConversationalBufferMemoryを継承してClaude3用に最適化したクラス | |
| Ref: https://docs.anthropic.com/claude/docs/use-xml-tags | |
| """ | |
| human_prefix: str = "<Human>" | |
| human_suffix: str = "</Human>" | |
| ai_prefix: str = "<AI>" | |
| ai_suffix: str = "</AI>" | |
| system_prefix: str = "<System>" | |
| system_suffix: str = "</System>" | |
| function_prefix: str = "<Function>" | |
| function_suffix: str = "</Function>" | |
| tool_prefix: str = "<Tool>" | |
| tool_suffix: str = "</Tool>" | |
| llm: BaseLanguageModel | |
| memory_key: str = "history" | |
| max_token_limit: int = 2000 | |
| @property | |
| def buffer(self) -> Any: | |
| """String buffer of memory.""" | |
| return self.buffer_as_messages if self.return_messages else self.buffer_as_str | |
| @property | |
| def buffer_as_str(self) -> str: | |
| """Exposes the buffer as a string in case return_messages is False.""" | |
| string_messages = [] | |
| for m in self.chat_memory.messages: | |
| if isinstance(m, HumanMessage): | |
| prefix = self.human_prefix | |
| suffix = self.human_suffix | |
| elif isinstance(m, AIMessage): | |
| prefix = self.ai_prefix | |
| suffix = self.ai_suffix | |
| elif isinstance(m, SystemMessage): | |
| prefix = self.system_prefix | |
| suffix = self.system_suffix | |
| elif isinstance(m, FunctionMessage): | |
| prefix = self.function_prefix | |
| suffix = self.function_suffix | |
| elif isinstance(m, ToolMessage): | |
| prefix = self.tool_prefix | |
| suffix = self.tool_suffix | |
| elif isinstance(m, ChatMessage): | |
| prefix = f"<{m.role}>" | |
| suffix = f"</{m.role}>" | |
| else: | |
| raise ValueError(f"Got unsupported message type: {m}") | |
| message = f"{prefix}{m.content}{suffix}" | |
| if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs: | |
| message += f"{m.additional_kwargs['function_call']}" | |
| string_messages.append(message) | |
| return "\n".join(string_messages) | |
| @property | |
| def buffer_as_messages(self) -> List[BaseMessage]: | |
| """Exposes the buffer as a list of messages in case return_messages is True.""" | |
| return self.chat_memory.messages | |
| @property | |
| def memory_variables(self) -> List[str]: | |
| """Will always return list of memory variables. | |
| :meta private: | |
| """ | |
| return [self.memory_key] | |
| def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |
| """Return history buffer.""" | |
| return {self.memory_key: self.buffer} | |
| def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: | |
| """Save context from this conversation to buffer. Pruned.""" | |
| super().save_context(inputs, outputs) | |
| # Prune buffer if it exceeds max token limit | |
| buffer = self.chat_memory.messages | |
| curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) | |
| if curr_buffer_length > self.max_token_limit: | |
| pruned_memory = [] | |
| while curr_buffer_length > self.max_token_limit: | |
| pruned_memory.append(buffer.pop(0)) | |
| curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment