Last active
October 23, 2024 18:21
-
-
Save grahama1970/aa9542768b81de017fff13d76b36f922 to your computer and use it in GitHub Desktop.
ArangoDB integration for LiteLLM caching, rather than Redis
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from types import SimpleNamespace | |
import litellm | |
from litellm.integrations.custom_logger import CustomLogger | |
from litellm import completion, acompletion, token_counter | |
import asyncio | |
from functools import wraps | |
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential | |
from litellm import RateLimitError, APIError | |
import os | |
from dotenv import load_dotenv | |
from loguru import logger | |
# import spacy | |
# Load spaCy model | |
# nlp = spacy.load("en_core_web_md") | |
from verifaix.arangodb_helper.arango_client import ( | |
generate_safe_key, | |
setup_arango_client, | |
connect_to_arango, | |
ensure_collection_exists, | |
get_document_by_key, | |
upsert_document_by_key, | |
) | |
load_dotenv('./.env') | |
# Custom callback that reports whether each successful LiteLLM call was cached
class CacheLogger(CustomLogger):
    """Logs a cache hit/miss line for every successful completion event."""

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        call_id = kwargs.get('litellm_call_id', 'unknown')
        if kwargs.get('cache_hit', False):
            logger.info(f"Cache hit for call id: '{call_id}'")
        else:
            logger.info(f"Cache miss for call id: '{call_id}'")
# Set up the custom logger
# Register the cache logger globally so every litellm call routes its
# success events through CacheLogger.async_log_success_event above.
litellm.callbacks = [CacheLogger()]
def litellm_cache_and_tenacity_retry(
    max_retries=3,
    wait_multiplier=1,
    wait_min=4,
    wait_max=10
):
    """Decorator factory adding tenacity retries to an async LiteLLM call.

    Retries up to *max_retries* attempts with exponential backoff bounded
    to the [wait_min, wait_max] second range, and only for the transient
    RateLimitError / APIError exception types.
    """
    def decorator(func):
        # Build the retry policy once, then stack it over the wrapped coroutine.
        retry_policy = retry(
            stop=stop_after_attempt(max_retries),
            wait=wait_exponential(multiplier=wait_multiplier, min=wait_min, max=wait_max),
            retry=retry_if_exception_type((RateLimitError, APIError)),
        )

        @retry_policy
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)

        return wrapper

    return decorator
@litellm_cache_and_tenacity_retry()
async def litellm_completion(db, collection_name, **kwargs):
    """Run a LiteLLM completion with an ArangoDB-backed response cache.

    Args:
        db: ArangoDB database handle returned by connect_to_arango.
        collection_name: Name of the cache collection.
        **kwargs: Forwarded to litellm.acompletion; must include a
            non-empty 'messages' list and should include 'model'.

    Returns:
        The litellm response object, or on a cache hit a SimpleNamespace
        mimicking its ``choices[0]["message"]["content"]`` shape.

    Raises:
        ValueError: If 'messages' is missing/empty, or the completion
            returns empty content.
    """
    messages = kwargs.get("messages")
    if not messages:
        raise ValueError("litellm_completion requires a non-empty 'messages' list.")

    # Key on the model plus every message's content, not just the first
    # message: keying only on messages[0] made different models — and
    # different multi-turn conversations sharing a first message — collide
    # on the same cache entry. (Note: this invalidates entries written
    # under the old first-message-only scheme.)
    raw_key = "|".join(
        [str(kwargs.get("model"))] + [str(m.get("content", "")) for m in messages]
    )
    cache_key = generate_safe_key(raw_key)  # sanitize into a valid Arango _key

    # Fetch any cached response off the event loop thread.
    cached_response = await asyncio.to_thread(get_document_by_key, db, collection_name, cache_key)
    if cached_response:
        kwargs['cache_hit'] = True
        logger.info(f"Cache hit for prompt: {cache_key}")
        # Mimic the litellm response shape the callers index into.
        response = SimpleNamespace(choices=[{"message": {"content": cached_response.get('value')}}])
        return response

    # Perform the actual completion request if not cached, timing it.
    start_time = asyncio.get_event_loop().time()
    response = await acompletion(**kwargs)
    end_time = asyncio.get_event_loop().time()

    # Extract the content from the response.
    response_content = response.choices[0]["message"].get("content")
    if not response_content:
        logger.error("Received empty content from the completion response.")
        raise ValueError("Completion response returned empty content.")

    # Use token_counter to analyze the response content.
    completion_token_count = token_counter(
        model=kwargs.get("model"),
        messages=[{"role": "assistant", "content": response_content}]
    )
    # Analyze input messages for prompt tokens.
    prompt_token_count = token_counter(
        model=kwargs.get("model"),
        messages=messages
    )
    total_tokens = prompt_token_count + completion_token_count

    # _hidden_params is an internal litellm attribute; guard in case a
    # future version drops or renames it.
    response_cost = getattr(response, "_hidden_params", {}).get("response_cost")

    response_time = end_time - start_time  # seconds taken for the response

    llm_params = {
        "model": kwargs.get("model"),
        "temperature": kwargs.get("temperature"),
        "max_tokens": kwargs.get("max_tokens"),
    }

    # Upsert the response along with usage metadata into the cache,
    # again off the event loop thread.
    await asyncio.to_thread(upsert_document_by_key, db, collection_name, cache_key, {
        "value": response_content,
        "num_tokens": total_tokens,
        "prompt_tokens": prompt_token_count,
        "completion_tokens": completion_token_count,
        "response_time": response_time,
        "llm_params": llm_params,
        "response_cost": response_cost
    })
    return response
###
# Usage example
###
async def usage_example():
    """Demonstrate cache miss / hit behavior against a local ArangoDB."""
    arango_settings = {
        "host": os.getenv("ARANGO_DB_HOST", "http://localhost:8529"),
        "db_name": "verifaix",
        "username": os.getenv("ARANGO_DB_USERNAME", "root"),
        "password": os.getenv("ARANGO_DB_PASSWORD", "openSesame"),
        "collection_name": "litellm_cache",
    }
    config = {"arango_config": arango_settings}

    # Connect to ArangoDB and make sure the cache collection is present.
    client, arango_config = setup_arango_client(config)
    db = connect_to_arango(client, arango_config)
    collection_name = arango_config['collection_name']
    ensure_collection_exists(db, collection_name)

    prompt = "What is the capital of France?"

    # First call (should miss cache)
    result1 = await litellm_completion(
        db, collection_name,
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        caching=True
    )
    logger.info(f"First call result: {result1.choices[0]['message']['content']}")

    # Second call with the same prompt (should hit cache)
    result2 = await litellm_completion(
        db, collection_name,
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        caching=True
    )
    logger.info(f"Second call result: {result2.choices[0]['message']['content']}")

    # Different prompt (should miss cache)
    different_prompt = "What is the capital of Japan?"
    result3 = await litellm_completion(
        db, collection_name,
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": different_prompt}]
    )
    logger.info(f"Different prompt result: {result3.choices[0]['message']['content']}")
async def main():
    """Async entry point: run the cache demonstration."""
    await usage_example()
# Script entry point: drive the async demo via asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment