Fast async Python CLI to batch-run prompts via LiteLLM with robust image support. Detects local and remote images in prompts, pre-downloads and inlines them as base64 data URIs, caches responses via Redis (with in-memory fallback), and accepts flexible prompt input (files, stdin, or inline). Uses Typer for the CLI.
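A minimal usage sketch for calling the batch runner from other Python code (assumes the first file below is saved as litellm_call.py and is importable; the prompts and image path are illustrative):

import asyncio
from litellm_call import litellm_call

# Strings, shorthand dicts, and raw LiteLLM dicts can be mixed in one batch;
# results come back in the same order as the prompts.
results = asyncio.run(litellm_call([
    "What is 2+2?",
    {"text": "Describe this image", "image": "photo.jpg"},  # hypothetical local file
]))
for answer in results:
    print(answer)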
#!/usr/bin/env python3
"""
LiteLLM Call - Easy async LLM batch runner with automatic image support

WHAT IT DOES:
- Run multiple LLM prompts in parallel for speed
- Automatically detects and includes images from URLs or local files
- Works with any LiteLLM-supported model (OpenAI, Anthropic, Ollama, etc.)
- Handles all image processing automatically (compression, base64 encoding)

QUICK START:
1. Basic text prompt:
   $ python litellm_call.py "What is 2+2?"
2. Multiple prompts (run in parallel):
   $ python litellm_call.py "What is 2+2?" "What is the capital of France?"
3. Prompt with images (auto-detected):
   $ python litellm_call.py "What's in this image? /path/to/image.jpg"
   $ python litellm_call.py "Compare: https://example.com/cat.jpg and dog.png"
4. From files:
   $ python litellm_call.py @prompts.txt     # One prompt per line
   $ python litellm_call.py prompts.json     # JSON array of prompts
   $ python litellm_call.py @prompts.jsonl   # JSON Lines format
5. From stdin:
   $ echo "What is 2+2?" | python litellm_call.py --stdin
   $ cat prompts.jsonl | python litellm_call.py --stdin --jsonl

ENVIRONMENT SETUP:
- OLLAMA_DEFAULT_MODEL: Model to use (default: "ollama/gemma3:12b")
- OLLAMA_BASE_URL: API endpoint (default: "http://localhost:11434")
- OLLAMA_API_KEY: API key if required

ADVANCED USAGE:
- Override model: --model "gpt-4"
- Custom API: --api-base "https://api.openai.com/v1"
- With API key: --api-key "sk-..."

INPUT FORMATS:
1. Simple string: "What is 2+2?"
2. With image: {"text": "Explain this", "image": "path/to/image.jpg"}
3. Full control: {"model": "gpt-4", "messages": [...], "temperature": 0.7}

FEATURES:
- Automatic image detection in prompts (URLs and file paths)
- Smart image compression to stay under API limits
- Parallel processing with progress bar
- Automatic retries on failures
- Silent handling of missing/broken images
- Supports all common image formats (jpg, png, gif, etc.)
"""
import asyncio
import sys
import json
import base64
import io
import os
import re
from pathlib import Path
from typing import List, Tuple, Any, Dict
from copy import deepcopy

import httpx
from PIL import Image
from litellm import acompletion
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm.asyncio import tqdm
from loguru import logger
from dotenv import load_dotenv, find_dotenv
from urlextract import URLExtract
import typer
from strip_tags import strip_tags

logger.remove()
logger.add(sys.stderr, level="WARNING")

from lean4_prover.utils.litellm_cache import initialize_litellm_cache

load_dotenv(find_dotenv())
initialize_litellm_cache()
# -----------------------------------------------------------------------------
# Typer app
# -----------------------------------------------------------------------------
cli = typer.Typer(
    name="litellm_call",
    help="Fast async LLM batch runner with inline image support via LiteLLM / Ollama.",
)

# Default model configuration - works with any LiteLLM provider
MODEL = os.getenv("LITELLM_MODEL", os.getenv("OLLAMA_DEFAULT_MODEL", "ollama/gemma3:12b"))

# Provider-specific configurations (LiteLLM will use the appropriate ones)
# For Ollama
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY", "")
# For Moonshot/Kimi
MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY", "")
MOONSHOT_API_BASE = os.getenv("MOONSHOT_API_BASE", "https://api.moonshot.ai/v1")
# For OpenAI
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# For Anthropic
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")

IMAGE_EXT = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"}
extractor = URLExtract()
SHOW_PROGRESS = os.getenv("LITELLM_NO_PROGRESS", "").lower() not in {"1", "true", "yes"}
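# Example .env (illustrative values; every variable above is optional):
#   LITELLM_MODEL=ollama/gemma3:12b
#   OLLAMA_BASE_URL=http://localhost:11434
#   OLLAMA_API_KEY=
#   LITELLM_NO_PROGRESS=1   # set to disable the tqdm progress bar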
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def safe_image(path: Path) -> bool:
    """True if file exists, has an image extension, and PIL can open it."""
    try:
        return path.exists() and path.suffix.lower() in IMAGE_EXT and Image.open(path).verify() is None
    except Exception:
        return False


def extract_images(text: str) -> tuple[List[str], str]:
    """
    Return:
      - list[str] of all valid image URLs/paths (remote & local)
      - cleaned prompt text with placeholders {Image 1}, {Image 2}, …
    """
    found, seen = [], set()
    raw_by_path: Dict[str, str] = {}  # resolved path -> token as it appeared in the prompt

    # 1) Strip XML/HTML tags ----------------------------------------------------
    plain = strip_tags(text)

    # 2) Remote URLs ------------------------------------------------------------
    for url in extractor.find_urls(plain):
        url = url.strip()
        if url.lower().endswith(tuple(IMAGE_EXT)) and url not in seen:
            found.append(url)
            seen.add(url)

    # 3) Local files ------------------------------------------------------------
    tokens = re.findall(r'(?:"[^"]*"|\'[^\']*\'|\S+)', plain)
    for tok in tokens:
        tok = tok.strip('"\'')
        if not tok:
            continue
        candidate = Path(tok).expanduser().resolve()
        if safe_image(candidate) and str(candidate) not in seen:
            found.append(str(candidate))
            seen.add(str(candidate))
            raw_by_path[str(candidate)] = tok  # remember the original spelling

    # 4) Build cleaned prompt with placeholders ----------------------------------
    cleaned = text
    for idx, img in enumerate(found, 1):
        placeholder = f"{{Image {idx}}}"
        # Replace the token as it appeared in the prompt, not the resolved absolute
        # path, so relative paths like ./dog.png are substituted correctly.
        cleaned = cleaned.replace(raw_by_path.get(img, img), placeholder)
    cleaned = re.sub(r"\s{2,}", " ", cleaned).strip()
    return found, cleaned
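# Illustrative example (paths hypothetical): for the prompt
#   'Compare https://example.com/cat.jpg and ./dog.png'
# extract_images returns, assuming ./dog.png exists and PIL can verify it:
#   (["https://example.com/cat.jpg", "/abs/path/to/dog.png"],
#    "Compare {Image 1} and {Image 2}")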
def compress_image(path_str: str, max_kb: int = 1000) -> str:
    """Return base-64 data-URI for a *local* image, compressed if required."""
    path = Path(path_str)
    img_bytes = path.read_bytes()
    max_bytes = max_kb * 1024
    if len(img_bytes) <= max_bytes:
        mime = f"image/{path.suffix[1:]}"
        return f"data:{mime};base64,{base64.b64encode(img_bytes).decode()}"
    img = Image.open(io.BytesIO(img_bytes))
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")  # JPEG cannot encode alpha/palette modes
    quality = 85
    while quality > 20:
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=quality, optimize=True)
        if len(buf.getvalue()) <= max_bytes:
            return f"data:image/jpeg;base64,{base64.b64encode(buf.getvalue()).decode()}"
        quality -= 10
        img.thumbnail((img.width // 2, img.height // 2))  # also halve dimensions each pass
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=30)
    return f"data:image/jpeg;base64,{base64.b64encode(buf.getvalue()).decode()}"
def fetch_remote_image(url: str) -> str | None:
    """Download remote image and return base-64 data-URI or None on failure."""
    try:
        # follow_redirects: many image CDNs 301/302 to a mirror; httpx does not
        # follow redirects by default
        r = httpx.get(url, timeout=10, follow_redirects=True)
        r.raise_for_status()
        mime = r.headers.get("content-type", "image/jpeg").split(";")[0]
        return f"data:{mime};base64,{base64.b64encode(r.content).decode()}"
    except Exception as e:
        logger.warning(f"Skipping remote image {url}: {e}")
        return None
# -----------------------------------------------------------------------------
# LiteLLM utilities
# -----------------------------------------------------------------------------
def _build_params(model: str,
                  messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return the final dict for acompletion, injecting only needed keys."""
    params = {"model": model, "messages": messages}
    # LiteLLM auto-detects most providers, but Ollama needs an explicit base URL
    if model.startswith("ollama/"):
        params["api_base"] = OLLAMA_BASE_URL
        if OLLAMA_API_KEY:
            params["api_key"] = OLLAMA_API_KEY
    return params
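# Illustrative results (actual values depend on your environment):
#   _build_params("ollama/gemma3:12b", msgs)
#   -> {"model": "ollama/gemma3:12b", "messages": msgs,
#       "api_base": "http://localhost:11434"}   # plus api_key if OLLAMA_API_KEY is set
#   _build_params("gpt-4", msgs)
#   -> {"model": "gpt-4", "messages": msgs}     # LiteLLM reads OPENAI_API_KEY itself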
# -----------------------------------------------------------------------------
# LLM call with retry
# -----------------------------------------------------------------------------
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=4))
async def _call(params: Dict[str, Any], idx: int) -> Tuple[int, str]:
    resp = await acompletion(**params)
    return idx, resp.choices[0].message.content
# -----------------------------------------------------------------------------
# Batch runner
# -----------------------------------------------------------------------------
async def litellm_call(prompts: List[Any]) -> List[str]:
    """Run prompts: strings, shorthand dicts, or raw LiteLLM dicts."""
    # Accept a single prompt as well
    if isinstance(prompts, (str, dict)):
        prompts = [prompts]
    tasks: List[asyncio.Task] = []
    for idx, item in enumerate(prompts):
        # 1) Raw LiteLLM dict ----------------------------------------------------
        if isinstance(item, dict) and "messages" in item:
            item.setdefault("model", MODEL)  # fill only if missing
            tasks.append(asyncio.create_task(_call(item, idx)))
            continue
        # 2) Shorthand dict or plain string ----------------------------------------
        if isinstance(item, dict):
            text = str(item.get("text", ""))
            images = [str(item["image"])] if "image" in item else []
            model = item.get("model", MODEL)
        else:
            images, text = extract_images(str(item))
            model = MODEL
        # Build OpenAI-style content
        content_parts: List[Dict[str, Any]] = [{"type": "text", "text": text}]
        for img in images:
            url = fetch_remote_image(img) if img.startswith("http") else compress_image(img)
            if url:
                content_parts.append({"type": "image_url", "image_url": {"url": url}})
        params = _build_params(model, [{"role": "user", "content": content_parts}])
        tasks.append(asyncio.create_task(_call(params, idx)))

    # Collect results in original order
    results = [None] * len(tasks)
    for coro in tqdm(
        asyncio.as_completed(tasks),
        total=len(tasks),
        desc="Processing",
        disable=not SHOW_PROGRESS,
    ):
        idx, answer = await coro
        results[idx] = answer
        # Redact secrets before logging
        safe_prompt = deepcopy(prompts[idx])
        if isinstance(safe_prompt, dict) and "api_key" in safe_prompt:
            safe_prompt["api_key"] = "***"
        logger.info(f"\nQ{idx}: {str(safe_prompt)[:50]}...\nA{idx}: {answer[:100]}...")
    return results
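# Usage sketch (prompts are illustrative):
#   answers = asyncio.run(litellm_call([
#       "What is 2+2?",                                   # plain string
#       {"text": "Caption this", "image": "cat.png"},     # shorthand dict
#       {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]},  # raw dict
#   ]))
# Results are returned in input order even though calls finish out of order.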
# -----------------------------------------------------------------------------
# Typer CLI
# -----------------------------------------------------------------------------
@cli.command()
def main(
    sources: List[str] = typer.Argument(None, help="Prompts or files containing prompts"),
    model: str = typer.Option(MODEL, "--model", "-m", help="LiteLLM model name (e.g., 'ollama/gemma3:12b', 'moonshot/kimi-k2-turbo-preview', 'gpt-4')"),
    api_base: str = typer.Option(None, "--api-base", help="Override API base URL (auto-detected from model)"),
    api_key: str = typer.Option(None, "--api-key", help="Override API key (auto-detected from environment)"),
    stdin: bool = typer.Option(False, "--stdin", help="Read prompts from stdin"),
    jsonl: bool = typer.Option(False, "--jsonl", help="Input is in JSON Lines format"),
):
    """
    Run any combination of prompts via LiteLLM.

    Examples\n
    --------\n
    # Basic usage with default model
    litellm_call "What is 2+2?"\n
    # Use different models
    litellm_call --model "moonshot/kimi-k2-turbo-preview" "Explain quantum physics"\n
    litellm_call --model "gpt-4" "Write a poem"\n
    litellm_call --model "ollama/llama2" "Hello world"\n
    # From files
    litellm_call @questions.txt\n
    litellm_call prompts.json\n
    cat lines.jsonl | litellm_call --stdin --jsonl\n
    """
    # Update globals if CLI overrides provided
    global MODEL, OLLAMA_BASE_URL, OLLAMA_API_KEY
    if model:
        MODEL = model
    # Set provider-specific overrides if provided
    if api_base:
        # Update the module global too: _build_params reads OLLAMA_BASE_URL directly,
        # so setting only the environment variable would have no effect here.
        OLLAMA_BASE_URL = api_base
        os.environ["OLLAMA_API_BASE"] = api_base
        os.environ["MOONSHOT_API_BASE"] = api_base
        os.environ["OPENAI_API_BASE"] = api_base
    if api_key:
        # Determine which provider based on model prefix
        if model.startswith("ollama/"):
            OLLAMA_API_KEY = api_key
            os.environ["OLLAMA_API_KEY"] = api_key
        elif model.startswith("moonshot/"):
            os.environ["MOONSHOT_API_KEY"] = api_key
        elif model.startswith("gpt") or model.startswith("text-"):
            os.environ["OPENAI_API_KEY"] = api_key
        elif model.startswith("claude"):
            os.environ["ANTHROPIC_API_KEY"] = api_key

    prompts: List[Any] = []
    # 1) STDIN
    if stdin or (sources == ["-"]):
        for line in sys.stdin:
            line = line.rstrip("\n")
            if jsonl:
                prompts.append(json.loads(line))
            else:
                prompts.append(line)
    # 2) Positional sources
    for src in sources or []:
        if src == "-":
            continue  # already handled above
        # @file expansion
        if src.startswith("@"):
            src = src[1:]
        path = Path(src)
        if not path.exists():
            # Treat as a literal prompt string
            prompts.append(src)
            continue
        # Decide how to parse the file
        if path.suffix.lower() == ".json":
            prompts.extend(json.loads(path.read_text()))
        elif path.suffix.lower() == ".jsonl" or jsonl:
            prompts.extend(json.loads(l) for l in path.read_text().splitlines() if l.strip())
        else:
            prompts.extend(path.read_text().splitlines())

    if not prompts:
        typer.echo("No prompts provided.", err=True)
        raise typer.Exit(1)

    results = asyncio.run(litellm_call(prompts))
    for r in results:
        typer.echo(r)
# ---------------------------------------------------------------------------
# Quick test
# ---------------------------------------------------------------------------
async def demo() -> List[str]:
    """
    Run a canned set of prompts and return the results.
    Safe to call from other async code.
    """
    prompts = [
        "What is the capital of France?",
        "Calculate 15+27+38",
        "What is 3 + 5? Return JSON: {question:string,answer:number}",
        "What is this animal eating? proof_of_concept/ollama_turbo/images/image2.png",
        "Describe https://upload.wikimedia.org/wikipedia/commons/thumb/9/90/Labrador_Retriever_portrait.jpg/960px-Labrador_Retriever_portrait.jpg and https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/960px-Cat_November_2010-1a.jpg",
        {"text": "Explain this meme", "image": "proof_of_concept/ollama_turbo/images/image.png"},
        {
            "model": "ollama/gpt-oss:120b",
            "api_base": "https://ollama.com",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Tell me a short joke."}
            ],
            "temperature": 1.0
        },
    ]
    return await litellm_call(prompts)


def demo_sync() -> List[str]:
    """
    Synchronous wrapper around `demo()` for callers that are not async.
    """
    return asyncio.run(demo())


if __name__ == "__main__":
    cli()  # run the Typer CLI; call demo_sync() instead for a quick smoke test
| """ | |
| Initializes LiteLLM Caching Configuration. | |
| Module: litellm_cache.py | |
| Description: Functions for litellm cache operations | |
| This module sets up LiteLLM's caching mechanism. It attempts to configure' | |
| Redis as the primary cache backend. If Redis is unavailable or fails connection | |
| tests, it falls back to using LiteLLM's built-in in-memory cache. Includes' | |
| a test function to verify cache functionality. | |
| Relevant Documentation: | |
| - LiteLLM Caching: https://docs.litellm.ai/docs/proxy/caching | |
| - Redis Python Client: https://redis.io/docs/clients/python/ | |
| - Project Caching Notes: ../../repo_docs/caching_strategy.md (Placeholder) | |
| Input/Output: | |
| - Input: Environment variables for Redis connection (optional). | |
| - Output: Configures LiteLLM global cache settings. Logs status messages. | |
| - The `test_litellm_cache` function demonstrates usage by making cached calls. | |
| """ | |
# ==============================================================================
# !!! WARNING - DO NOT MODIFY THIS FILE !!!
# ==============================================================================
# This is a core functionality file that other parts of the system depend on.
# Changing this file (especially its sync/async nature) will break multiple dependent systems.
# Any changes here will cascade into test failures across the codebase.
#
# If you think you need to modify this file:
# 1. DON'T. The synchronous implementation is intentional
# 2. The caching system is working as designed
# 3. Test files should adapt to this implementation, not vice versa
# 4. Consult LESSONS_LEARNED.md about not breaking working code
# ==============================================================================
import os
import sys

import litellm
import redis
from dotenv import load_dotenv
from litellm.caching.caching import (  # Cache class and cache-type enum
    Cache as LiteLLMCache,
    LiteLLMCacheType,
)
from loguru import logger

from lean4_prover.utils.log_utils import truncate_large_value

# Docker Compose handles .env loading via env_file; load_dotenv covers local runs
load_dotenv()
def initialize_litellm_cache() -> None:
    redis_host = os.getenv("REDIS_HOST", "localhost")
    redis_port = int(os.getenv("REDIS_PORT", 6379))
    redis_password = os.getenv("REDIS_PASSWORD", None)  # password is optional

    try:
        logger.debug(
            f"Starting LiteLLM cache initialization (Redis target: {redis_host}:{redis_port})..."
        )
        # Test Redis connection before enabling caching
        logger.debug("Testing Redis connection...")
        test_redis = redis.Redis(
            host=redis_host,
            port=redis_port,
            password=redis_password,
            socket_timeout=2,
            decode_responses=True,  # decode to str for easier debugging
        )
        if not test_redis.ping():
            # Raise redis.ConnectionError (not the builtin ConnectionError) so the
            # fallback handler below actually catches it.
            raise redis.ConnectionError(
                f"Redis is not responding at {redis_host}:{redis_port}."
            )

        # Log existing keys, or note that the cache is empty
        keys = test_redis.keys("*")
        if keys:
            logger.debug(f"Existing Redis keys: {truncate_large_value(keys)}")
        else:
            logger.debug("Redis cache is empty")

        # Set up LiteLLM cache with debug logging
        logger.debug("Configuring LiteLLM Redis cache...")
        litellm.cache = LiteLLMCache(
            type=LiteLLMCacheType.REDIS,
            host=redis_host,
            port=str(redis_port),  # LiteLLM expects the port as a string
            password=redis_password,
            supported_call_types=["acompletion", "completion"],
            ttl=60 * 60 * 24 * 2,  # 2 days
        )

        # Enable caching and verify
        logger.debug("Enabling LiteLLM cache...")
        litellm.enable_cache()

        # Set debug logging for LiteLLM
        os.environ["LITELLM_LOG"] = "DEBUG"

        # Verify cache configuration
        logger.debug(
            f"LiteLLM cache config: {litellm.cache.__dict__ if hasattr(litellm.cache, '__dict__') else 'No cache config available'}"
        )
        logger.info(f"✅ Redis caching enabled on {redis_host}:{redis_port}")

        # Try a test set/get to verify Redis is working
        try:
            test_key = "litellm_cache_test"
            test_redis.set(test_key, "test_value", ex=60)
            result = test_redis.get(test_key)
            logger.debug(f"Redis test write/read successful: {result == 'test_value'}")
            test_redis.delete(test_key)
        except Exception as e:
            logger.warning(f"Redis test write/read failed: {e}")

    except (redis.ConnectionError, redis.TimeoutError) as e:
        logger.warning(
            f"⚠️ Redis connection failed: {e}. Falling back to in-memory caching."
        )
        # Fall back to in-memory caching if Redis is unavailable
        logger.debug("Configuring in-memory cache fallback...")
        litellm.cache = LiteLLMCache(type=LiteLLMCacheType.LOCAL)
        litellm.enable_cache()
        logger.debug("In-memory cache enabled")
def test_litellm_cache() -> tuple[bool, dict[str, bool | None]]:
    """
    Test the LiteLLM cache functionality with a sample completion call.
    Returns a tuple: (overall_success, cache_hit_details)
    """
    initialize_litellm_cache()
    test_success = False
    # Explicitly annotate the dictionary type
    cache_details: dict[str, bool | None] = {"cache_hit1": None, "cache_hit2": None}
    try:
        # Test the cache with a simple completion call
        test_messages = [
            {
                "role": "user",
                "content": "What is the capital of France? Respond concisely.",
            }
        ]
        logger.info("Testing cache with completion call...")

        # First call - expect a cache miss (cache_hit should be False or None)
        response1 = litellm.completion(
            model=os.getenv("LITELLM_TEST_MODEL", os.getenv("LEAN4_MODEL")),  # test model, else main model
            messages=test_messages,
        )
        usage1 = getattr(response1, "usage", "N/A")
        hidden_params1 = getattr(response1, "_hidden_params", {})
        cache_hit1 = hidden_params1.get("cache_hit")  # None if not hit or feature disabled
        cache_details["cache_hit1"] = cache_hit1
        logger.info(f"First call usage: {usage1}")
        logger.info(f"Response 1: Cache hit: {cache_hit1}")
        # First call should NOT be a cache hit (allow None or False)
        first_call_missed = cache_hit1 is None or cache_hit1 is False

        # Second call - expect a cache hit (cache_hit should be True)
        response2 = litellm.completion(
            model=os.getenv("LITELLM_TEST_MODEL", os.getenv("LEAN4_MODEL")),
            messages=test_messages,
        )
        usage2 = getattr(response2, "usage", "N/A")
        hidden_params2 = getattr(response2, "_hidden_params", {})
        cache_hit2 = hidden_params2.get("cache_hit")
        cache_details["cache_hit2"] = cache_hit2
        logger.info(f"Second call usage: {usage2}")
        logger.info(f"Response 2: Cache hit: {cache_hit2}")
        # Second call SHOULD be a cache hit
        second_call_hit = cache_hit2 is True

        # Determine overall test success
        test_success = first_call_missed and second_call_hit
    except Exception as e:
        logger.error(f"Cache test failed during execution: {e}")
        test_success = False  # ensure failure is recorded
    return test_success, cache_details
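# Example (model resolved from LITELLM_TEST_MODEL, falling back to LEAN4_MODEL):
#   ok, details = test_litellm_cache()
#   # ok is True when the first call missed and the second hit, i.e. details
#   # looks like {"cache_hit1": None (or False), "cache_hit2": True}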
if __name__ == "__main__":
    logger.info("--- Running Standalone Validation for litellm_cache.py ---")
    tests_passed_count = 0
    tests_failed_count = 0
    total_tests = 1
    try:
        test_result, details = test_litellm_cache()
        if test_result:
            tests_passed_count += 1
            logger.success("✅ Test 'cache_hit_miss': PASSED")
        else:
            tests_failed_count += 1
            logger.error("❌ Test 'cache_hit_miss': FAILED")
            logger.error(
                "  Expected first call cache_hit=False/None, second call cache_hit=True."
            )
            logger.error(
                f"  Got: cache_hit1={details.get('cache_hit1')}, cache_hit2={details.get('cache_hit2')}"
            )
    except Exception as e:
        tests_failed_count += 1  # count exception as failure
        # logger.exception logs the traceback; loguru does not use exc_info=
        logger.exception(f"❌ Test 'cache_hit_miss': FAILED due to exception during test execution: {e}")

    # --- Report validation status ---
    print("\n--- Test Summary ---")
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {tests_passed_count}")
    print(f"Failed: {tests_failed_count}")
    if tests_failed_count == 0:
        print("\n✅ VALIDATION COMPLETE - All LiteLLM cache tests passed.")
        sys.exit(0)
    else:
        print("\n❌ VALIDATION FAILED - LiteLLM cache test failed.")
        # Error details already logged above
        sys.exit(1)