ArangoDB integration for LiteLLM caching (instead of Redis)
from types import SimpleNamespace
import asyncio
import os
from functools import wraps

import litellm
from litellm import acompletion, token_counter, RateLimitError, APIError
from litellm.integrations.custom_logger import CustomLogger
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from dotenv import load_dotenv
from loguru import logger

# import spacy
# Load spaCy model
# nlp = spacy.load("en_core_web_md")

from verifaix.arangodb_helper.arango_client import (
    generate_safe_key,
    setup_arango_client,
    connect_to_arango,
    ensure_collection_exists,
    get_document_by_key,
    upsert_document_by_key,
)
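# NOTE: the verifaix helpers above are project-local (not on PyPI). As a rough
# sketch of what they are assumed to do, minimal stand-ins built directly on
# python-arango might look like the following (names and behavior here are
# assumptions, not the actual verifaix implementation):
#
#   import hashlib
#   from arango import ArangoClient
#
#   def generate_safe_key(raw: str) -> str:
#       # Hash the raw prompt into a valid ArangoDB _key.
#       return hashlib.sha256(raw.encode("utf-8")).hexdigest()
#
#   def setup_arango_client(config):
#       cfg = config["arango_config"]
#       return ArangoClient(hosts=cfg["host"]), cfg
#
#   def connect_to_arango(client, cfg):
#       return client.db(cfg["db_name"], username=cfg["username"], password=cfg["password"])
#
#   def ensure_collection_exists(db, name):
#       if not db.has_collection(name):
#           db.create_collection(name)
#
#   def get_document_by_key(db, collection, key):
#       return db.collection(collection).get(key)  # None if absent
#
#   def upsert_document_by_key(db, collection, key, doc):
#       db.collection(collection).insert({**doc, "_key": key}, overwrite=True)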
load_dotenv('./.env')


# Custom logger to handle cache hit/miss logging
class CacheLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        cache_hit = kwargs.get('cache_hit', False)
        if cache_hit:
            logger.info(f"Cache hit for call id: '{kwargs.get('litellm_call_id', 'unknown')}'")
        else:
            logger.info(f"Cache miss for call id: '{kwargs.get('litellm_call_id', 'unknown')}'")


# Set up the custom logger
litellm.callbacks = [CacheLogger()]
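# Note: this callback only fires when acompletion() actually runs, i.e. on the
# ArangoDB cache-miss path below (litellm_completion returns early on a hit,
# before litellm is invoked). The cache_hit flag it receives reflects litellm's
# own built-in cache, not the ArangoDB lookup.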
def litellm_cache_and_tenacity_retry(
    max_retries=3,
    wait_multiplier=1,
    wait_min=4,
    wait_max=10
):
    """Retry transient LiteLLM failures with exponential backoff."""
    def decorator(func):
        @retry(
            stop=stop_after_attempt(max_retries),
            wait=wait_exponential(multiplier=wait_multiplier, min=wait_min, max=wait_max),
            retry=retry_if_exception_type((RateLimitError, APIError))
        )
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
        return wrapper
    return decorator
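# With these settings, tenacity raises tenacity.RetryError (wrapping the last
# RateLimitError/APIError) once max_retries attempts are exhausted; pass
# reraise=True to @retry if the original exception should propagate instead.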
@litellm_cache_and_tenacity_retry()
async def litellm_completion(db, collection_name, **kwargs):
    raw_key = kwargs.get("messages")[0]["content"]
    cache_key = generate_safe_key(raw_key)  # Use the sanitized key

    # Check ArangoDB for a cached response before calling the model
    cached_response = await asyncio.to_thread(get_document_by_key, db, collection_name, cache_key)
    if cached_response:
        kwargs['cache_hit'] = True
        logger.info(f"Cache hit for prompt: {cache_key}")
        response = SimpleNamespace(choices=[{"message": {"content": cached_response.get('value')}}])
        return response

    # Perform the actual completion request if not cached
    start_time = asyncio.get_event_loop().time()
    response = await acompletion(**kwargs)
    end_time = asyncio.get_event_loop().time()

    # Extract the content from the response
    response_content = response.choices[0]["message"].get("content")
    if not response_content:
        logger.error("Received empty content from the completion response.")
        raise ValueError("Completion response returned empty content.")

    # Use token_counter to count tokens in the response content
    completion_token_count = token_counter(
        model=kwargs.get("model"),
        messages=[{"role": "assistant", "content": response_content}]
    )

    # Count prompt tokens from the input messages
    prompt_token_count = token_counter(
        model=kwargs.get("model"),
        messages=kwargs.get("messages", [])
    )

    # Total tokens used
    total_tokens = prompt_token_count + completion_token_count

    # Extract response cost
    response_cost = response._hidden_params.get("response_cost")

    # Time taken for the response, in seconds
    response_time = end_time - start_time

    llm_params = {
        "model": kwargs.get("model"),
        "temperature": kwargs.get("temperature"),
        "max_tokens": kwargs.get("max_tokens"),
    }

    # Upsert the response along with usage metadata into the cache
    await asyncio.to_thread(upsert_document_by_key, db, collection_name, cache_key, {
        "value": response_content,
        "num_tokens": total_tokens,
        "prompt_tokens": prompt_token_count,
        "completion_tokens": completion_token_count,
        "response_time": response_time,
        "llm_params": llm_params,
        "response_cost": response_cost,
    })

    return response
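# Caveat: the cache key is derived only from the first message's content, so
# cached responses are shared across models, temperatures, and other
# parameters. If that matters, fold kwargs such as the model name into
# raw_key before sanitizing.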
###
# Usage example
###
async def usage_example():
    config = {
        "arango_config": {
            "host": os.getenv("ARANGO_DB_HOST", "http://localhost:8529"),
            "db_name": "verifaix",
            "username": os.getenv("ARANGO_DB_USERNAME", "root"),
            "password": os.getenv("ARANGO_DB_PASSWORD", "openSesame"),
            "collection_name": "litellm_cache"
        }
    }
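    # The fallback values above target a local dev ArangoDB; in real
    # deployments, supply credentials via the .env file loaded earlier.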
    # Set up the ArangoDB client and connect to the database
    client, arango_config = setup_arango_client(config)
    db = connect_to_arango(client, arango_config)
    ensure_collection_exists(db, arango_config['collection_name'])
    collection_name = arango_config['collection_name']

    prompt = "What is the capital of France?"

    # First call (should miss the cache)
    result1 = await litellm_completion(
        db, collection_name,
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        caching=True
    )
    logger.info(f"First call result: {result1.choices[0]['message']['content']}")

    # Second call with the same prompt (should hit the cache)
    result2 = await litellm_completion(
        db, collection_name,
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        caching=True
    )
    logger.info(f"Second call result: {result2.choices[0]['message']['content']}")

    # Different prompt (should miss the cache)
    different_prompt = "What is the capital of Japan?"
    result3 = await litellm_completion(
        db, collection_name,
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": different_prompt}]
    )
    logger.info(f"Different prompt result: {result3.choices[0]['message']['content']}")


async def main():
    await usage_example()


if __name__ == "__main__":
    asyncio.run(main())