Gist by @tubone24, created June 18, 2024 07:53
from typing import Any, Dict, List
from langchain_core.language_models import BaseLanguageModel
from langchain.memory.chat_memory import BaseChatMemory
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.messages.chat import ChatMessage
from langchain_core.messages.function import FunctionMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.messages.tool import ToolMessage


class OptimizedClaude3ConversationalBufferMemory(BaseChatMemory):
    """Conversation buffer memory optimized for Claude 3.

    Claude 3 tends to handle XML-formatted data well, so this class adapts
    the ConversationBufferMemory behavior by wrapping each message in XML
    tags when the history is rendered as a string.
    Ref: https://docs.anthropic.com/claude/docs/use-xml-tags
    """

    human_prefix: str = "<Human>"
    human_suffix: str = "</Human>"
    ai_prefix: str = "<AI>"
    ai_suffix: str = "</AI>"
    system_prefix: str = "<System>"
    system_suffix: str = "</System>"
    function_prefix: str = "<Function>"
    function_suffix: str = "</Function>"
    tool_prefix: str = "<Tool>"
    tool_suffix: str = "</Tool>"
    llm: BaseLanguageModel
    memory_key: str = "history"
    max_token_limit: int = 2000

    @property
    def buffer(self) -> Any:
        """String buffer of memory."""
        return self.buffer_as_messages if self.return_messages else self.buffer_as_str

    @property
    def buffer_as_str(self) -> str:
        """Exposes the buffer as a string in case return_messages is False."""
        string_messages = []
        for m in self.chat_memory.messages:
            if isinstance(m, HumanMessage):
                prefix = self.human_prefix
                suffix = self.human_suffix
            elif isinstance(m, AIMessage):
                prefix = self.ai_prefix
                suffix = self.ai_suffix
            elif isinstance(m, SystemMessage):
                prefix = self.system_prefix
                suffix = self.system_suffix
            elif isinstance(m, FunctionMessage):
                prefix = self.function_prefix
                suffix = self.function_suffix
            elif isinstance(m, ToolMessage):
                prefix = self.tool_prefix
                suffix = self.tool_suffix
            elif isinstance(m, ChatMessage):
                # Generic chat messages carry their own role; use it as the tag name.
                prefix = f"<{m.role}>"
                suffix = f"</{m.role}>"
            else:
                raise ValueError(f"Got unsupported message type: {m}")
            message = f"{prefix}{m.content}{suffix}"
            if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
                # Append any function-call payload after the tagged content.
                message += f"{m.additional_kwargs['function_call']}"
            string_messages.append(message)
        return "\n".join(string_messages)

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Exposes the buffer as a list of messages in case return_messages is True."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        return {self.memory_key: self.buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer, pruning if needed."""
        super().save_context(inputs, outputs)
        # Drop the oldest messages until the buffer fits within max_token_limit.
        # Pruned messages are collected but simply discarded, not summarized.
        buffer = self.chat_memory.messages
        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
        if curr_buffer_length > self.max_token_limit:
            pruned_memory = []
            while curr_buffer_length > self.max_token_limit:
                pruned_memory.append(buffer.pop(0))
                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
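

# A minimal usage sketch, assuming the langchain-anthropic package is installed;
# the model name and the sample conversation are placeholders, and token counting
# for pruning is delegated to the model via get_num_tokens_from_messages.
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-sonnet-20240229")  # example model name
memory = OptimizedClaude3ConversationalBufferMemory(llm=llm, max_token_limit=2000)

# Record one exchange, then read the history back as an XML-tagged string.
memory.save_context({"input": "Hello!"}, {"output": "Hi, how can I help?"})
print(memory.load_memory_variables({})["history"])
# Expected output:
# <Human>Hello!</Human>
# <AI>Hi, how can I help?</AI>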