Created October 23, 2024 15:00
Added a Different Redis implementation to LiteLLM
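The file below wires a small asynchronous Redis read-through cache around litellm's acompletion: responses are stored under a key derived from the prompt, transient RateLimitError/APIError failures are retried with tenacity's exponential backoff, and a CustomLogger reports cache hits and misses.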
from types import SimpleNamespace
import asyncio
import atexit
import os
from functools import wraps

import litellm
from litellm import acompletion, APIError, RateLimitError
from litellm.integrations.custom_logger import CustomLogger
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from dotenv import load_dotenv
from loguru import logger
# aioredis was merged into redis-py; redis.asyncio exposes the same interface
from redis import asyncio as aioredis

load_dotenv('./.env')


class RedisCache:
    """Thin async wrapper around a Redis client, used here as a completion cache."""

    def __init__(self, host='localhost', port=6379, password=None):
        self.host = host
        self.port = port
        self.password = password
        self.client = None

    async def connect(self):
        # from_url builds the client synchronously; connections open lazily
        # on the first command, so there is nothing to await here
        self.client = aioredis.from_url(
            f"redis://{self.host}:{self.port}",
            password=self.password
        )
        # Register automatic cleanup for non-async exit paths
        atexit.register(self._close_sync)

    async def get(self, key):
        if self.client:
            return await self.client.get(key)

    async def set(self, key, value):
        if self.client:
            await self.client.set(key, value)

    async def close(self):
        if self.client:
            await self.client.close()
            self.client = None  # make the atexit hook a no-op after an explicit close
            logger.info("Redis connection closed.")

    def _close_sync(self):
        """Ensure Redis closes automatically on exit, even in non-async contexts."""
        if self.client:
            asyncio.run(self.close())


# Initialize Redis Cache
redis_cache = RedisCache(
    host=os.environ.get('REDIS_HOST', 'localhost'),
    port=int(os.environ.get('REDIS_PORT', 6379)),
    password=os.environ.get('REDIS_PASSWORD', 'openSesame')
)


# Custom logger to handle cache hit/miss logging.
# Note: a cache hit returns before litellm is ever called, so this callback
# only fires for actual completions, i.e. cache misses.
class CacheLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        cache_hit = kwargs.get('cache_hit', False)
        if cache_hit:
            logger.info(f"Cache hit for call id: '{kwargs.get('litellm_call_id', 'unknown')}'")
        else:
            logger.info(f"Cache miss for call id: '{kwargs.get('litellm_call_id', 'unknown')}'")


# Set up the custom logger
litellm.callbacks = [CacheLogger()]


def litellm_cache_and_tenacity_retry(
    max_retries=3,
    wait_multiplier=1,
    wait_min=4,
    wait_max=10
):
    """Retry transient LiteLLM failures with exponential backoff."""
    def decorator(func):
        @retry(
            stop=stop_after_attempt(max_retries),
            wait=wait_exponential(multiplier=wait_multiplier, min=wait_min, max=wait_max),
            retry=retry_if_exception_type((RateLimitError, APIError))
        )
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
        return wrapper
    return decorator


@litellm_cache_and_tenacity_retry()
async def litellm_completion(**kwargs):
    # Key the cache on the first message's content (see the keying note below)
    cache_key = kwargs["messages"][0]["content"]
    cached_response = await redis_cache.get(cache_key)
    if cached_response:
        kwargs['cache_hit'] = True
        logger.info(f"Cache hit for prompt: {cache_key}")
        # Mimic the litellm response shape so callers can index
        # choices[0]["message"]["content"] on either path
        return SimpleNamespace(choices=[{"message": {"content": cached_response.decode('utf-8')}}])
    # Perform the actual completion request if not cached
    response = await acompletion(**kwargs)
    await redis_cache.set(cache_key, response.choices[0]["message"]["content"])
    return response


async def usage_example():
    prompt = "What is the capital of France?"

    # First call (should miss cache)
    result1 = await litellm_completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        caching=True
    )
    logger.info(f"First call result: {result1.choices[0]['message']['content']}")

    # Second call with the same prompt (should hit cache)
    result2 = await litellm_completion(
        model="gpt-4o-mini",  # Use the same model to ensure cache hit
        messages=[{"role": "user", "content": prompt}],
        caching=True
    )
    logger.info(f"Second call result: {result2.choices[0]['message']['content']}")

    # Different prompt (should miss cache)
    different_prompt = "What is the capital of Japan?"
    result3 = await litellm_completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": different_prompt}]
    )
    logger.info(f"Different prompt result: {result3.choices[0]['message']['content']}")


async def main():
    await redis_cache.connect()
    await usage_example()
    await redis_cache.close()


if __name__ == "__main__":
    asyncio.run(main())
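One refinement worth noting, sketched below rather than part of the gist: keying the cache on only the first message's content means two requests that share a first message but differ in model or later turns would collide, and cached entries never expire. Hashing the model together with the full message list, and passing redis-py's `ex=` TTL argument to `set`, addresses both. `make_cache_key` and `CACHE_TTL_SECONDS` are hypothetical names introduced here for illustration.

import hashlib
import json

CACHE_TTL_SECONDS = 3600  # hypothetical TTL; tune for your workload

def make_cache_key(model: str, messages: list) -> str:
    """Derive a stable, collision-resistant key from the model and full chat history."""
    payload = json.dumps({"model": model, "messages": messages}, sort_keys=True)
    return "litellm:" + hashlib.sha256(payload.encode("utf-8")).hexdigest()

# Inside litellm_completion, the key and the write would then become:
#   cache_key = make_cache_key(kwargs["model"], kwargs["messages"])
#   await redis_cache.client.set(cache_key, content, ex=CACHE_TTL_SECONDS)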