Last active
October 2, 2024 17:04
-
-
Save vilsonrodrigues/4fbca7c7fd3923ceddc2cc738f3736ea to your computer and use it in GitHub Desktop.
A multimodal, multi-user chat history implementation in Python. ChatHistory can expire conversations based on inactivity time, and it supports reranker methods to search for relevant user messages.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import math | |
import re | |
from collections import defaultdict | |
from datetime import datetime, timedelta | |
from typing import Callable, Dict, List, Optional, Union | |
class BasicDBConnector:
    """ Minimal in-memory connector that stores histories in a dict.

    Lists are copied on the way in (``save``) and on the way out (``load``) so
    that callers never share a list object with the internal storage. Without
    the copies, appending to a loaded history would silently modify the stored
    data and a later ``save`` would duplicate those messages.
    Only the lists are copied; the message objects themselves are shared.
    """

    def __init__(self):
        # Maps user_id -> list of stored messages.
        self.storage = {}

    def save(self, user_id: str, history):
        """ Appends messages to the ones already stored for a user.

        Args:
            user_id:
                Key under which the messages are stored.
            history:
                Iterable of messages to append.
        """
        # ``extend`` copies the items into a list owned by the connector,
        # so the caller's list is never kept by reference.
        self.storage.setdefault(user_id, []).extend(history)

    def load(self, user_id):
        """ Returns a copy of the messages stored for a user.

        Args:
            user_id:
                Key to look up.

        Returns:
            A new list with the stored messages, or an empty list when the
            user has nothing stored.
        """
        return list(self.storage.get(user_id, []))
def _convert_time_str_to_seconds(time_str: str) -> int: | |
""" Convert time str in secounds. | |
Args: | |
time_str: | |
Time string in format '2h', '10m', '50s'. | |
Returns: | |
Time in secounds. | |
""" | |
if isinstance(time_str, str): | |
match = re.match(r"(\d+)([hms])", time_str) | |
if not match: | |
raise ValueError("Invalid format. Use 'h' for hours, 'm' for minutes, or 's' for seconds.") | |
value, unit = int(match.group(1)), match.group(2) | |
if unit == 'h': | |
return value * 3600 | |
elif unit == 'm': | |
return value * 60 | |
elif unit == 's': | |
return value | |
else: | |
raise ValueError("time_str must be a string.") | |
def _bm25( | |
documents: list[str], | |
query: str, | |
top_k: Optional[int] = None, | |
k1: Optional[float] = 1.5, | |
b: Optional[float] = 0.75 | |
) -> List[str]: | |
""" BM25 (TF-IDF-like) reranker algorithm implementation | |
Args: | |
documents: | |
A document list to search. | |
query: | |
Query to search. | |
top_k: | |
Top-k similarity documents to return. | |
k1: | |
BM25 tuning parameter. | |
b: | |
Length normalization parameter in BM25. | |
Returns: | |
A document list classified by relevance. | |
""" | |
tokenized_docs = [doc.lower().split() for doc in documents] | |
tokenized_query = query.lower().split() | |
# Calculation of frequencies and inverse document frequency (IDF) | |
doc_freqs = defaultdict(int) | |
for doc in tokenized_docs: | |
for word in set(doc): | |
doc_freqs[word] += 1 | |
N = len(tokenized_docs) | |
avg_doc_len = sum(len(doc) for doc in tokenized_docs) / N | |
# Calculate the BM25 score of a single document | |
def bm25_score(doc, query_words): | |
doc_len = len(doc) | |
score = 0.0 | |
for word in query_words: | |
if word in doc: | |
f = doc.count(word) | |
df = doc_freqs.get(word, 0) | |
idf = math.log(1 + (N - df + 0.5) / (df + 0.5)) | |
numerator = f * (k1 + 1) | |
denominator = f + k1 * (1 - b + b * (doc_len / avg_doc_len)) | |
score += idf * (numerator / denominator) | |
return score | |
scores = [(doc, bm25_score(doc, tokenized_query)) for doc in tokenized_docs] | |
ranked_docs = sorted(scores, key=lambda x: x[1], reverse=True) | |
ordered_ranked_docs = [documents[idx] for idx, _ in enumerate(ranked_docs)] | |
return ordered_ranked_docs[-top_k:] if top_k else ordered_ranked_docs | |
class ChatHistory:
    """ Class responsible for managing chat history.

    ChatHistory is capable of loading and saving conversations in a database.
    It implements a self-sanitization system based on temporal flow: users
    whose history was not modified for ``expiration_time`` are removed by a
    periodic cleanup routine.
    When a user is removed and ``save_on_delete`` is set, only the messages
    not yet saved are written to the database.
    ChatHistory supports implementation of Reranker algorithms in user conversations.
    The default algorithm is a simple implementation of BM25.

    All state is kept per instance; two ChatHistory objects never share
    histories.
    """

    chat_history: Dict[str, List[dict]]  # User chat history, keyed by user_id
    last_modified: Dict[str, datetime]  # Last modification timestamp by user_id
    last_saved_index: Dict[str, int]  # Index of the last message saved to the DB

    def __init__(
        self,
        expiration_time: Optional[str] = None,
        sanitize_interval: Optional[str] = "1h",
        connector: Optional[Callable] = None,
        reranker_method: Optional[Callable] = _bm25,
        save_on_delete: Optional[bool] = False,
    ):
        """
        Args:
            expiration_time: Time to remove inactive history ('2h', '10m', '50s').
            sanitize_interval: Interval for the cleaning routine ('2h', '10m', '50s').
            connector: Database connector instance.
            reranker_method: Method to search user relevant conversations.
            save_on_delete: If True, save history to DB before deletion.
        """
        # Created here (not as class attributes) so instances do not share state.
        self.chat_history = {}
        self.last_modified = {}
        self.last_saved_index = {}

        self.connector = connector
        self.reranker_method = reranker_method
        self.save_on_delete = save_on_delete
        self.sanitize_interval = _convert_time_str_to_seconds(sanitize_interval)
        self.expiration_time = (_convert_time_str_to_seconds(expiration_time)
                                if expiration_time else None)

        # Reference to the background cleanup task; keeping it prevents the
        # task from being garbage-collected while it is still pending.
        self._cleanup_task = None
        if self.expiration_time is not None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # Constructed outside an event loop: the periodic cleanup
                # cannot be scheduled here. Callers may still await
                # start_cleanup_routine() later or call
                # remove_inactive_histories() manually.
                pass
            else:
                self._cleanup_task = loop.create_task(
                    self.start_cleanup_routine(self.sanitize_interval)
                )

    def add(
        self,
        user_id: str,
        role: str,
        content: str,
        audios: Optional[List[str]] = None,
        images: Optional[List[str]] = None,
    ):
        """ Adds a message to a user's history and updates the modification timestamp.

        Args:
            user_id:
                Message User-ID.
            role:
                user or assistant.
            content:
                Message content.
            audios:
                List of audio URLs (if applicable for multimodal input).
            images:
                List of image URLs (if applicable for multimodal input).
        """
        history = self.chat_history.setdefault(user_id, [])

        if audios or images:
            # Multimodal content: a list of typed parts (text, then images, then audios).
            multimodal_content = []
            if content:
                multimodal_content.append({"type": "text", "text": content})
            for image_url in images or []:
                multimodal_content.append({"type": "image_url", "image_url": {"url": image_url}})
            for audio_url in audios or []:
                multimodal_content.append({"type": "audio_url", "audio_url": {"url": audio_url}})
            history.append({"role": role, "content": multimodal_content})
        else:
            # Simple text content
            history.append({"role": role, "content": content})

        self.last_modified[user_id] = datetime.now()

    def get_history(self, user_id: str, top_k: Optional[int] = None) -> List[Dict[str, str]]:
        """ Returns a user's chat history.

        Args:
            user_id:
                User-ID to retrieval chat history.
            top_k:
                The number of most recent messages to return.
                If None, returns the entire history.

        Returns:
            List of dictionaries containing message history in chat template format.
            An empty list is returned for unknown users.
        """
        if user_id not in self.chat_history:
            return []
        history = self.chat_history[user_id]
        return history[-top_k:] if top_k else history

    def count_messages(self, user_id: str) -> int:
        """ Returns the count of messages for a given user."""
        return len(self.get_history(user_id))

    def remove_user(self, user_id: str):
        """ Removes the user's history from the current state.

        When ``save_on_delete`` is set and a connector is configured, unsaved
        messages are written to the database first.
        """
        if user_id in self.chat_history:
            if self.save_on_delete and self.connector:
                self.save_to_db(user_id)  # Save before deleting
            del self.chat_history[user_id]
        self.last_modified.pop(user_id, None)
        self.last_saved_index.pop(user_id, None)

    def save_to_db(self, user_id: str):
        """ Saves a user's not-yet-saved messages to the database.

        Does nothing when no connector is configured.
        """
        if not self.connector:
            return
        history = self.get_history(user_id)
        last_saved = self.last_saved_index.get(user_id, -1)
        # Save only new messages after the last saved index
        unsaved_messages = history[last_saved + 1:]
        if unsaved_messages:
            self.connector.save(user_id, unsaved_messages)
            # Update the last saved index after saving
            self.last_saved_index[user_id] = len(history) - 1

    def load_from_db(self, user_id: str):
        """ Loads a user's history from the database.

        A non-empty stored history replaces the in-memory history of the user
        and is marked as fully saved. Does nothing when no connector is
        configured or nothing is stored.
        """
        if not self.connector:
            return
        history = self.connector.load(user_id)
        if history:
            # Copy so that later add() calls do not mutate a list that the
            # connector may still hold internally.
            self.chat_history[user_id] = list(history)
            self.last_saved_index[user_id] = len(history) - 1  # Mark all as saved
            self.last_modified[user_id] = datetime.now()

    def remove_inactive_histories(self):
        """ Removes histories not modified within ``expiration_time``.

        Does nothing when no expiration time is configured.
        """
        if self.expiration_time is None:
            return
        current_time = datetime.now()
        expiration_delta = timedelta(seconds=self.expiration_time)
        # Collect first: remove_user() mutates last_modified.
        inactive_users = [user_id for user_id, last_mod in self.last_modified.items()
                          if current_time - last_mod > expiration_delta]
        for user_id in inactive_users:
            self.remove_user(user_id)

    async def start_cleanup_routine(self, sanitize_interval: int):
        """ Starts a cleanup routine to remove inactive histories periodically.

        Runs forever (until cancelled), sleeping ``sanitize_interval`` seconds
        between cleanup passes.
        """
        while True:
            await asyncio.sleep(sanitize_interval)
            self.remove_inactive_histories()

    def search_messages(self, user_id, query, top_k: Optional[int] = None) -> List[str]:
        """ Fetches messages from a user based on the specified retrieval method.

        Only the text of messages with role 'user' is searched; for multimodal
        messages the parts of type 'text' are used.

        Args:
            user_id:
                User-ID to fetch.
            query:
                Query to search in user messages.
            top_k:
                The number of the user's most relevant messages to return.
                If None, returns the entire history.

        Returns:
            List of messages sorted by relevance. Empty when the user is
            unknown or has no textual user messages.
        """
        if user_id not in self.chat_history:
            return []

        # Extract the textual content from user messages, handling both simple text and multimodal content
        user_messages = []
        for msg in self.chat_history[user_id]:
            if msg['role'] != 'user':
                continue
            content = msg['content']
            if isinstance(content, str):
                # Simple text message
                user_messages.append(content)
            elif isinstance(content, list):
                # Multimodal content: extract where 'type' is 'text'
                user_messages.extend(
                    item['text'] for item in content if item.get('type') == 'text'
                )

        if not user_messages:
            # Nothing to rank; do not hand an empty corpus to the reranker.
            return []
        return self.reranker_method(user_messages, query, top_k)
async def _run_demo() -> None:
    """ Small end-to-end demo of ChatHistory with the in-memory connector.

    Runs as a coroutine so that an event loop is active while ChatHistory is
    constructed with an expiration time (the constructor schedules its
    periodic cleanup as an asyncio task).
    """
    connector = BasicDBConnector()
    chat_history = ChatHistory(
        expiration_time='60s',
        sanitize_interval='10s',
        connector=connector,
        save_on_delete=True,
    )

    chat_history.add('user1', 'user', 'Hi, I need help.')
    chat_history.add('user1', 'assistant', 'Sure, what do you need?')
    chat_history.save_to_db('user1')

    chat_history.add(
        user_id="user1",
        role="user",
        content="Can you identify the objects in these media?",
        images=["https://image1.jpg", "https://image2.jpg"],
        audios=["https://audio1.mp3"]
    )
    chat_history.save_to_db('user1')
    chat_history.load_from_db('user1')

    print(f"Num messages in history: {chat_history.count_messages('user1')} \n")
    print(f"DB user message: \n {connector.load('user1')} \n")
    print(f"User current state: \n {chat_history.get_history('user1')} \n")
    print(f"Ranked messages: \n {chat_history.search_messages('user1', 'objects', 1)} \n")


if __name__ == "__main__":
    # asyncio.run provides the running loop and cancels the pending cleanup
    # task when the demo finishes.
    asyncio.run(_run_demo())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment