Skip to content

Instantly share code, notes, and snippets.

@grahama1970
Created October 23, 2024 15:00
Show Gist options
  • Save grahama1970/4975d648d7a96f8e560221fbb6f411a1 to your computer and use it in GitHub Desktop.
Added a Different Redis implementation to LiteLLM
from types import SimpleNamespace
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm import completion, acompletion
import asyncio
from functools import wraps
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from litellm import RateLimitError, APIError
import os
from dotenv import load_dotenv
from loguru import logger
import aioredis
import atexit
load_dotenv('./.env')
class RedisCache:
    """Thin async wrapper around an aioredis client with automatic cleanup.

    The connection is created lazily via ``connect()``; ``get``/``set`` are
    silent no-ops until then.  An ``atexit`` hook closes the connection if
    the caller forgets to.
    """

    def __init__(self, host='localhost', port=6379, password=None):
        self.host = host
        self.port = port
        self.password = password
        self.client = None  # set by connect(); None means "not connected"
        self._atexit_registered = False  # guard so reconnects don't stack hooks

    async def connect(self):
        """Open the Redis connection and register automatic cleanup once."""
        self.client = await aioredis.from_url(
            f"redis://{self.host}:{self.port}",
            password=self.password
        )
        # Register the cleanup hook only once, even across reconnects.
        if not self._atexit_registered:
            atexit.register(self._close_sync)
            self._atexit_registered = True

    async def get(self, key):
        """Return the raw bytes stored under *key*, or None if absent or not connected."""
        if self.client:
            return await self.client.get(key)
        return None

    async def set(self, key, value):
        """Store *value* under *key*; silently does nothing when not connected."""
        if self.client:
            await self.client.set(key, value)

    async def close(self):
        """Close the connection; idempotent, so the atexit hook cannot double-close."""
        if self.client:
            await self.client.close()
            self.client = None  # forget the client so a second close is a no-op
            logger.info("Redis connection closed.")

    def _close_sync(self):
        """Ensure Redis closes automatically on exit, even in non-async contexts."""
        if self.client is None:
            return
        try:
            asyncio.run(self.close())
        except RuntimeError:
            # An event loop is still running (or already torn down) at
            # interpreter shutdown; nothing safe to do here.
            pass
# Initialize Redis Cache
# Module-level singleton consumed by litellm_completion below.
redis_cache = RedisCache(
    host=os.environ.get('REDIS_HOST', 'localhost'),
    port=int(os.environ.get('REDIS_PORT', 6379)),
    # NOTE(review): hardcoded fallback password — keep credentials in .env only.
    password=os.environ.get('REDIS_PASSWORD', 'openSesame')
)
# Custom logger to handle cache hit/miss logging
class CacheLogger(CustomLogger):
    """Reports whether each successful LiteLLM call was served from cache."""

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Look up the call id once; default mirrors LiteLLM's 'unknown' convention.
        call_id = kwargs.get('litellm_call_id', 'unknown')
        if kwargs.get('cache_hit', False):
            logger.info(f"Cache hit for call id: '{call_id}'")
        else:
            logger.info(f"Cache miss for call id: '{call_id}'")
# Set up the custom logger
# Registers CacheLogger globally so LiteLLM invokes it on every successful call.
litellm.callbacks = [CacheLogger()]
def litellm_cache_and_tenacity_retry(
    max_retries=3,
    wait_multiplier=1,
    wait_min=4,
    wait_max=10
):
    """Decorator factory: retry an async call on transient LiteLLM errors.

    Retries up to *max_retries* attempts with exponential backoff bounded to
    the [wait_min, wait_max] second range, and only for RateLimitError or
    APIError — other exceptions propagate immediately.
    """
    def decorator(func):
        # Build the tenacity policy once per decorated function.
        retrying = retry(
            stop=stop_after_attempt(max_retries),
            wait=wait_exponential(multiplier=wait_multiplier, min=wait_min, max=wait_max),
            retry=retry_if_exception_type((RateLimitError, APIError)),
        )

        @retrying
        @wraps(func)
        async def inner(*args, **kwargs):
            return await func(*args, **kwargs)

        return inner
    return decorator
@litellm_cache_and_tenacity_retry()
async def litellm_completion(**kwargs):
    """Async completion with a Redis read-through cache keyed on model + prompt.

    Expects ``messages`` (a non-empty list of chat messages) and ``model`` in
    *kwargs*.  On a cache hit, returns a SimpleNamespace mimicking the response
    shape used by callers (``response.choices[0]["message"]["content"]``);
    otherwise performs the real ``acompletion`` and stores the answer.
    """
    messages = kwargs.get("messages") or []
    if not messages:
        # Nothing to key the cache on; let LiteLLM raise its own validation error.
        return await acompletion(**kwargs)

    # Include the model in the key so different models never share a cached
    # answer (previously the key was the prompt alone, which was wrong).
    cache_key = f"{kwargs.get('model', '')}:{messages[0]['content']}"
    cached_response = await redis_cache.get(cache_key)
    if cached_response:
        # NOTE: no LiteLLM callback fires on this early return, so CacheLogger
        # only ever observes misses — known limitation of this short-circuit.
        kwargs['cache_hit'] = True
        logger.info(f"Cache hit for prompt: {cache_key}")
        return SimpleNamespace(
            choices=[{"message": {"content": cached_response.decode('utf-8')}}]
        )

    # Cache miss: perform the actual completion request and store the answer.
    response = await acompletion(**kwargs)
    await redis_cache.set(cache_key, response.choices[0]["message"]["content"])
    return response
async def usage_example():
    """Demonstrate cache behaviour: a miss, a hit, then a miss on a new prompt."""
    prompt = "What is the capital of France?"

    # First call (should miss cache)
    first = await litellm_completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        caching=True,
    )
    logger.info(f"First call result: {first.choices[0]['message']['content']}")

    # Second call with the same prompt (should hit cache)
    second = await litellm_completion(
        model="gpt-4o-mini",  # Use the same model to ensure cache hit
        messages=[{"role": "user", "content": prompt}],
        caching=True,
    )
    logger.info(f"Second call result: {second.choices[0]['message']['content']}")

    # Different prompt (should miss cache)
    different_prompt = "What is the capital of Japan?"
    third = await litellm_completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": different_prompt}],
    )
    logger.info(f"Different prompt result: {third.choices[0]['message']['content']}")
async def main():
    """Connect to Redis, run the demo, and always release the connection."""
    await redis_cache.connect()
    try:
        await usage_example()
    finally:
        # Close even if the demo raises, so the connection never leaks.
        await redis_cache.close()


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment