Skip to content

Instantly share code, notes, and snippets.

@vilsonrodrigues
Last active October 2, 2024 17:04
Show Gist options
  • Save vilsonrodrigues/4fbca7c7fd3923ceddc2cc738f3736ea to your computer and use it in GitHub Desktop.
Save vilsonrodrigues/4fbca7c7fd3923ceddc2cc738f3736ea to your computer and use it in GitHub Desktop.
A multimodal, multi-user chat history implementation in Python. ChatHistory can expire chats based on elapsed time. It implements reranker methods to search for relevant user messages.
import asyncio
import math
import re
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable, Dict, List, Optional, Union
class BasicDBConnector:
    """ Minimal in-memory connector that mimics a database.
    Messages are kept in a dict keyed by user id. Lists are copied on
    save and on load so the caller's list and the stored list never alias
    each other (the copy is shallow: message dicts themselves are shared).
    """

    def __init__(self):
        # Maps user_id -> list of stored messages.
        self.storage = {}

    def save(self, user_id: str, history):
        """ Appends messages to the stored history of a user.
        Args:
            user_id:
                User-ID whose history is extended.
            history:
                Iterable of messages to append.
        """
        # extend() copies the items into a list owned by the connector, so
        # later mutations of the caller's list do not leak into storage.
        self.storage.setdefault(user_id, []).extend(history)

    def load(self, user_id):
        """ Returns a copy of the stored history of a user.
        Args:
            user_id:
                User-ID to load.
        Returns:
            A new list with the stored messages, empty if the user is unknown.
        """
        # Return a copy: handing out the internal list would let callers
        # append to storage by accident and cause duplicated saves.
        return list(self.storage.get(user_id, []))
def _convert_time_str_to_seconds(time_str: str) -> int:
""" Convert time str in secounds.
Args:
time_str:
Time string in format '2h', '10m', '50s'.
Returns:
Time in secounds.
"""
if isinstance(time_str, str):
match = re.match(r"(\d+)([hms])", time_str)
if not match:
raise ValueError("Invalid format. Use 'h' for hours, 'm' for minutes, or 's' for seconds.")
value, unit = int(match.group(1)), match.group(2)
if unit == 'h':
return value * 3600
elif unit == 'm':
return value * 60
elif unit == 's':
return value
else:
raise ValueError("time_str must be a string.")
def _bm25(
documents: list[str],
query: str,
top_k: Optional[int] = None,
k1: Optional[float] = 1.5,
b: Optional[float] = 0.75
) -> List[str]:
""" BM25 (TF-IDF-like) reranker algorithm implementation
Args:
documents:
A document list to search.
query:
Query to search.
top_k:
Top-k similarity documents to return.
k1:
BM25 tuning parameter.
b:
Length normalization parameter in BM25.
Returns:
A document list classified by relevance.
"""
tokenized_docs = [doc.lower().split() for doc in documents]
tokenized_query = query.lower().split()
# Calculation of frequencies and inverse document frequency (IDF)
doc_freqs = defaultdict(int)
for doc in tokenized_docs:
for word in set(doc):
doc_freqs[word] += 1
N = len(tokenized_docs)
avg_doc_len = sum(len(doc) for doc in tokenized_docs) / N
# Calculate the BM25 score of a single document
def bm25_score(doc, query_words):
doc_len = len(doc)
score = 0.0
for word in query_words:
if word in doc:
f = doc.count(word)
df = doc_freqs.get(word, 0)
idf = math.log(1 + (N - df + 0.5) / (df + 0.5))
numerator = f * (k1 + 1)
denominator = f + k1 * (1 - b + b * (doc_len / avg_doc_len))
score += idf * (numerator / denominator)
return score
scores = [(doc, bm25_score(doc, tokenized_query)) for doc in tokenized_docs]
ranked_docs = sorted(scores, key=lambda x: x[1], reverse=True)
ordered_ranked_docs = [documents[idx] for idx, _ in enumerate(ranked_docs)]
return ordered_ranked_docs[-top_k:] if top_k else ordered_ranked_docs
class ChatHistory:
    """ Class responsible for managing chat history.
    ChatHistory is capable of loading and saving conversations in a database.
    It implements a self-sanitization system based on temporal flow:
    histories that were not modified for `expiration_time` are removed by a
    periodic cleanup task. The task is only scheduled when the instance is
    created inside a running asyncio event loop; otherwise call
    `remove_inactive_histories` manually or schedule `start_cleanup_routine`.
    When `save_on_delete` is enabled, only the messages not yet saved are
    written to the database before a user's state is removed.
    ChatHistory supports implementation of Reranker algorithms in user conversations.
    The default algorithm is a simple implementation of BM25.
    """

    # Declared here for documentation only; the dicts are created per
    # instance in __init__ so that separate instances never share state.
    chat_history: Dict[str, List[Dict[str, Union[str, list]]]]  # User chat history
    last_modified: Dict[str, datetime]  # Last modification timestamp by user_id
    last_saved_index: Dict[str, int]  # Index of the last message saved to the DB

    def __init__(
        self,
        expiration_time: Optional[str] = None,
        sanitize_interval: Optional[str] = "1h",
        connector: Optional[Callable] = None,
        reranker_method: Optional[Callable] = _bm25,
        save_on_delete: Optional[bool] = False,
    ):
        """
        Args:
            expiration_time: Time to remove inactive history ('2h', '10m', '50s').
                If None, histories never expire.
            sanitize_interval: Interval for the cleaning routine ('2h', '10m', '50s').
            connector: Database connector instance exposing
                `save(user_id, messages)` and `load(user_id)`.
            reranker_method: Method to search user relevant conversations,
                called as `reranker_method(documents, query, top_k)`.
            save_on_delete: If True, save history to DB before deletion.
        """
        self.chat_history = {}
        self.last_modified = {}
        self.last_saved_index = {}
        self.connector = connector
        self.reranker_method = reranker_method
        self.save_on_delete = save_on_delete
        self.sanitize_interval = _convert_time_str_to_seconds(sanitize_interval)
        self.expiration_time = (_convert_time_str_to_seconds(expiration_time)
                                if expiration_time else None)
        # A reference to the task is kept because the event loop only holds
        # weak references to tasks; an unreferenced task may be collected.
        self._cleanup_task = None
        if self.expiration_time is not None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running event loop (plain synchronous usage): the
                # periodic cleanup cannot be scheduled automatically.
                loop = None
            if loop is not None:
                self._cleanup_task = loop.create_task(
                    self.start_cleanup_routine(self.sanitize_interval)
                )

    def add(
        self,
        user_id: str,
        role: str,
        content: str,
        audios: Optional[List[str]] = None,
        images: Optional[List[str]] = None,
    ):
        """ Adds a message to a user's history and updates the modification timestamp.
        Args:
            user_id:
                Message User-ID.
            role:
                user or assistant.
            content:
                Message content.
            audios:
                List of audio URLs (if applicable for multimodal input).
            images:
                List of image URLs (if applicable for multimodal input).
        """
        messages = self.chat_history.setdefault(user_id, [])
        if audios or images:
            # Multimodal content: a list of typed parts (text first, then
            # images, then audios).
            parts = []
            if content:
                parts.append({"type": "text", "text": content})
            for image_url in images or []:
                parts.append({"type": "image_url", "image_url": {"url": image_url}})
            for audio_url in audios or []:
                parts.append({"type": "audio_url", "audio_url": {"url": audio_url}})
            messages.append({"role": role, "content": parts})
        else:
            # Simple text content
            messages.append({"role": role, "content": content})
        self.last_modified[user_id] = datetime.now()

    def get_history(self, user_id: str, top_k: int = None) -> List[Dict[str, str]]:
        """ Returns a user's chat history.
        Args:
            user_id:
                User-ID to retrieval chat history.
            top_k:
                The number of most recent messages to return.
                If None, returns the entire history.
        Returns:
            List of dictionaries containing message history in chat template format.
        """
        if user_id not in self.chat_history:
            return []
        history = self.chat_history[user_id]
        return history[-top_k:] if top_k else history

    def count_messages(self, user_id: str) -> int:
        """ Returns the count of messages for a given user."""
        return len(self.get_history(user_id))

    def remove_user(self, user_id: str):
        """ Removes the user's history from the current state.
        When `save_on_delete` is set and a connector is configured, the
        unsaved messages are written to the database first.
        """
        if user_id in self.chat_history:
            if self.save_on_delete and self.connector:
                self.save_to_db(user_id)  # Save before deleting
            del self.chat_history[user_id]
        self.last_modified.pop(user_id, None)
        self.last_saved_index.pop(user_id, None)

    def save_to_db(self, user_id: str):
        """ Saves a user's not-yet-saved messages to the database."""
        if not self.connector:
            return
        history = self.get_history(user_id)
        last_saved = self.last_saved_index.get(user_id, -1)
        # Save only new messages after the last saved index
        unsaved_messages = history[last_saved + 1:]
        if unsaved_messages:
            self.connector.save(user_id, unsaved_messages)
        # Update the last saved index after saving
        self.last_saved_index[user_id] = len(history) - 1

    def load_from_db(self, user_id: str):
        """ Loads a user's history from the database.
        The in-memory history of the user is replaced by the stored one.
        Nothing changes when no connector is set or nothing is stored.
        """
        if not self.connector:
            return
        history = self.connector.load(user_id)
        if history:
            # Copy so that later add() calls do not mutate a list that the
            # connector may still own (which would duplicate saved messages).
            self.chat_history[user_id] = list(history)
            self.last_saved_index[user_id] = len(history) - 1  # Mark all as saved
            self.last_modified[user_id] = datetime.now()

    def remove_inactive_histories(self):
        """ Removes histories not modified within `expiration_time`.
        Does nothing when no expiration time was configured.
        """
        if self.expiration_time is None:
            return
        current_time = datetime.now()
        expiration_delta = timedelta(seconds=self.expiration_time)
        # Collect first: remove_user mutates last_modified while we iterate.
        inactive_users = [user_id for user_id, last_mod in self.last_modified.items()
                          if current_time - last_mod > expiration_delta]
        for user_id in inactive_users:
            self.remove_user(user_id)

    async def start_cleanup_routine(self, sanitize_interval: int):
        """ Starts a cleanup routine to remove inactive histories periodically.
        Args:
            sanitize_interval:
                Seconds to sleep between two cleanup passes.
        """
        while True:
            await asyncio.sleep(sanitize_interval)
            self.remove_inactive_histories()

    def search_messages(self, user_id, query, top_k: int = None) -> List[str]:
        """ Fetches messages from a user based on the specified retrieval method.
        Only messages with role 'user' are searched.
        Args:
            user_id:
                User-ID to fetch.
            query:
                Query to search in user messages.
            top_k:
                The number of the user's most relevant messages to return.
                If None, returns the entire history.
        Returns:
            List of messages sorted by relevance.
        """
        if user_id not in self.chat_history:
            return []
        # Extract the textual content from user messages, handling both
        # simple text and multimodal content
        user_messages = []
        for msg in self.chat_history[user_id]:
            if msg['role'] != 'user':
                continue
            content = msg['content']
            if isinstance(content, str):
                # Simple text message
                user_messages.append(content)
            elif isinstance(content, list):
                # Multimodal content: extract where 'type' is 'text'
                user_messages.extend(
                    item['text'] for item in content if item.get('type') == 'text'
                )
        # Nothing to rank; avoids handing an empty corpus to the reranker.
        if not user_messages:
            return []
        return self.reranker_method(user_messages, query, top_k)
async def _demo():
    """ Small end-to-end demonstration of ChatHistory with an in-memory connector.
    It runs as a coroutine because ChatHistory schedules its periodic cleanup
    task on construction when an expiration time is set, which requires a
    running event loop.
    """
    connector = BasicDBConnector()
    chat_history = ChatHistory(
        expiration_time='60s',
        sanitize_interval='10s',
        connector=connector,
        save_on_delete=True,
    )
    chat_history.add('user1', 'user', 'Hi, I need help.')
    chat_history.add('user1', 'assistant', 'Sure, what do you need?')
    chat_history.save_to_db('user1')
    chat_history.add(
        user_id="user1",
        role="user",
        content="Can you identify the objects in these media?",
        images=["https://image1.jpg", "https://image2.jpg"],
        audios=["https://audio1.mp3"]
    )
    chat_history.save_to_db('user1')
    chat_history.load_from_db('user1')
    print(f"Num messages in history: {chat_history.count_messages('user1')} \n")
    print(f"DB user message: \n {connector.load('user1')} \n")
    print(f"User current state: \n {chat_history.get_history('user1')} \n")
    print(f"Ranked messages: \n {chat_history.search_messages('user1', 'objects', 1)} \n")


if __name__ == "__main__":
    # asyncio.run provides the event loop and cancels the pending cleanup
    # task when the demo finishes.
    asyncio.run(_demo())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment