@grahama1970
Created October 23, 2024 13:24
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm import completion, acompletion, Cache
import asyncio
from functools import wraps
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from litellm import RateLimitError, APIError, ModelResponse
import os
from dotenv import load_dotenv
from loguru import logger
import atexit
load_dotenv('./.env')
# Initialize LiteLLM cache backed by Redis
litellm.cache = Cache(
    type="redis",
    host=os.environ.get('REDIS_HOST', 'localhost'),
    port=int(os.environ.get('REDIS_PORT', 6379)),
    password=os.environ.get('REDIS_PASSWORD', 'openSesame')
)
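
# Note: LiteLLM derives the cache key from the call arguments (model, messages, etc.),
# so repeating the same prompt against the same model should return the cached response.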

# Custom logger to handle cache hit/miss logging
class CacheLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        cache_hit = kwargs.get('cache_hit', False)
        call_id = kwargs.get('litellm_call_id', 'unknown')
        if cache_hit:
            logger.info(f"Cache hit for call id: '{call_id}'")
        else:
            logger.info(f"Cache miss for call id: '{call_id}'")

# Set up the custom logger
litellm.callbacks = [CacheLogger()]

def litellm_cache_and_tenacity_retry(
    max_retries=3,
    wait_multiplier=1,
    wait_min=4,
    wait_max=10
):
    def decorator(func):
        @retry(
            stop=stop_after_attempt(max_retries),
            wait=wait_exponential(multiplier=wait_multiplier, min=wait_min, max=wait_max),
            retry=retry_if_exception_type((RateLimitError, APIError))
        )
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # kwargs['function_name'] = func.__name__  # Add function name to kwargs for logging
            return await func(*args, **kwargs)
        return wrapper
    return decorator

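# Note: tenacity's @retry detects coroutine functions and awaits them, so the decorator
# above can wrap async LiteLLM calls directly. Retries fire only on RateLimitError/APIError;
# cached responses never reach the provider, so they are returned without retrying.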
@litellm_cache_and_tenacity_retry()
async def litellm_completion(**kwargs):
    return await acompletion(**kwargs)

# Function to handle closing the cache asynchronously
async def close_cache():
    if litellm.cache and hasattr(litellm.cache, '_redis_client'):
        await litellm.cache._redis_client.close()  # Use the internal redis client if exposed
        logger.info("Redis connection closed.")

# Ensure proper cleanup of the cache connection at interpreter exit
def ensure_async_cache_cleanup():
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running (the usual case at exit): run cleanup in a fresh loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(close_cache())
        loop.close()
    else:
        # A loop is still running; schedule the cleanup on it instead of blocking
        loop.create_task(close_cache())

# Register the close method to run on exit
atexit.register(ensure_async_cache_cleanup)
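
# Note: atexit handlers run after asyncio.run() in the main block has returned and closed
# its loop, so close_cache() normally executes in the fresh event loop created above.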

async def usage_example():
    prompt = "What is the capital of France?"

    # First call (should miss the cache)
    result1 = await litellm_completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        caching=True
    )
    logger.info(f"First call result: {result1.choices[0].message.content}")

    # Second call with the same prompt (should hit the cache)
    result2 = await litellm_completion(
        model="gpt-4o-mini",  # Use the same model to ensure a cache hit
        messages=[{"role": "user", "content": prompt}],
        caching=True
    )
    logger.info(f"Second call result: {result2.choices[0].message.content}")

    # Different prompt (should miss the cache)
    different_prompt = "What is the capital of Japan?"
    result3 = await litellm_completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": different_prompt}],
        caching=True
    )
    logger.info(f"Different prompt result: {result3.choices[0].message.content}")

async def main():
    await usage_example()

if __name__ == "__main__":
    asyncio.run(main())