@grahama1970
Last active August 18, 2025 20:41
Fast async Python CLI to batch run prompts via LiteLLM with robust image support. Supports local/remote images in prompts, pre-downloads and inlines them, cache-enabled with Redis fallback, and flexible prompt input (files, stdin, or inline). Uses Typer for CLI.
#!/usr/bin/env python3
"""
LiteLLM Call - Easy async LLM batch runner with automatic image support
WHAT IT DOES:
- Run multiple LLM prompts in parallel for speed
- Automatically detects and includes images from URLs or local files
- Works with any LiteLLM-supported model (OpenAI, Anthropic, Ollama, etc.)
- Handles all image processing automatically (compression, base64 encoding)
QUICK START:
1. Basic text prompt:
$ python litellm_call.py "What is 2+2?"
2. Multiple prompts (run in parallel):
$ python litellm_call.py "What is 2+2?" "What is the capital of France?"
3. Prompt with images (auto-detected):
$ python litellm_call.py "What's in this image? /path/to/image.jpg"
$ python litellm_call.py "Compare: https://example.com/cat.jpg and dog.png"
4. From files:
$ python litellm_call.py @prompts.txt # One prompt per line
$ python litellm_call.py prompts.json # JSON array of prompts
$ python litellm_call.py @prompts.jsonl # JSON Lines format
5. From stdin:
$ echo "What is 2+2?" | python litellm_call.py --stdin
$ cat prompts.jsonl | python litellm_call.py --stdin --jsonl
ENVIRONMENT SETUP:
- OLLAMA_DEFAULT_MODEL: Model to use (default: "ollama/gemma3:12b")
- OLLAMA_BASE_URL: API endpoint (default: "http://localhost:11434")
- OLLAMA_API_KEY: API key if required
ADVANCED USAGE:
- Override model: --model "gpt-4"
- Custom API: --api-base "https://api.openai.com/v1"
- With API key: --api-key "sk-..."
INPUT FORMATS:
1. Simple string: "What is 2+2?"
2. With image: {"text": "Explain this", "image": "path/to/image.jpg"}
3. Full control: {"model": "gpt-4", "messages": [...], "temperature": 0.7}
FEATURES:
- Automatic image detection in prompts (URLs and file paths)
- Smart image compression to stay under API limits
- Parallel processing with progress bar
- Automatic retries on failures
- Silent handling of missing/broken images
- Supports all common image formats (jpg, png, gif, etc.)
"""
import asyncio
import sys
import json
import base64
import io
import os
import re
from pathlib import Path
from typing import List, Tuple, Any, Dict
from copy import deepcopy
import httpx
from PIL import Image
from litellm import acompletion
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm.asyncio import tqdm
from loguru import logger
from dotenv import load_dotenv, find_dotenv
from urlextract import URLExtract
import typer
from strip_tags import strip_tags
logger.remove()
logger.add(sys.stderr, level="WARNING")
from lean4_prover.utils.litellm_cache import initialize_litellm_cache
load_dotenv(find_dotenv())
initialize_litellm_cache()
# -----------------------------------------------------------------------------
# Typer app
# -----------------------------------------------------------------------------
cli = typer.Typer(
name="litellm_call",
help="Fast async LLM batch runner with inline image support via LiteLLM / Ollama.",
)
# Default model configuration - works with any LiteLLM provider
MODEL = os.getenv("LITELLM_MODEL", os.getenv("OLLAMA_DEFAULT_MODEL", "ollama/gemma3:12b"))
# Provider-specific configurations (LiteLLM will use the appropriate ones)
# For Ollama
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY", "")
# For Moonshot/Kimi
MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY", "")
MOONSHOT_API_BASE = os.getenv("MOONSHOT_API_BASE", "https://api.moonshot.ai/v1")
# For OpenAI
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# For Anthropic
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
IMAGE_EXT = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"}
extractor = URLExtract()
SHOW_PROGRESS = os.getenv("LITELLM_NO_PROGRESS", "").lower() not in {"1", "true", "yes"}
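# Example environment overrides (values are illustrative; the variable names are the
# ones read above):
#   LITELLM_MODEL=moonshot/kimi-k2-turbo-preview   # or OLLAMA_DEFAULT_MODEL for Ollama
#   OLLAMA_BASE_URL=http://my-ollama-host:11434
#   MOONSHOT_API_KEY=sk-...                        # provider keys are picked up by LiteLLM
#   LITELLM_NO_PROGRESS=1                          # hide the tqdm progress bar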
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def safe_image(path: Path) -> bool:
"""True if file exists, has an image extension, and PIL can open it."""
try:
return path.exists() and path.suffix.lower() in IMAGE_EXT and Image.open(path).verify() is None
except Exception:
return False
def extract_images(text: str) -> tuple[List[str], str]:
"""
Return:
- list[str] of all valid image URLs/paths (remote & local)
- cleaned prompt text with placeholders {Image 1}, {Image 2}, …
"""
found, seen = [], set()
# 1) Strip XML/HTML tags ----------------------------------------------------
plain = strip_tags(text)
# 2) Remote URLs -----------------------------------------------------------
for url in extractor.find_urls(plain):
url = url.strip()
if url.lower().endswith(tuple(IMAGE_EXT)) and url not in seen:
found.append(url)
seen.add(url)
# 3) Local files -----------------------------------------------------------
tokens = re.findall(r'(?:"[^"]*"|\'[^\']*\'|\S+)', plain)
for tok in tokens:
tok = tok.strip('"\'')
if not tok:
continue
candidate = Path(tok).expanduser().resolve()
if safe_image(candidate) and str(candidate) not in seen:
found.append(str(candidate))
seen.add(str(candidate))
# 4) Build cleaned prompt with placeholders --------------------------------
cleaned = text
for idx, img in enumerate(found, 1):
placeholder = f"{{Image {idx}}}"
cleaned = cleaned.replace(img, placeholder)
cleaned = re.sub(r"\s{2,}", " ", cleaned).strip()
return found, cleaned
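# Example (illustrative; assumes ./dog.png exists and is a valid image):
#   extract_images("Compare dog.png and https://example.com/cat.jpg")
#   -> (["https://example.com/cat.jpg", "/abs/path/to/dog.png"],
#       "Compare dog.png and {Image 1}")
# Remote URLs are collected before local files, and local paths are returned as
# resolved absolute paths, so the {Image N} placeholder substitution only rewrites
# a local reference when the prompt already used that absolute path.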
def compress_image(path_str: str, max_kb: int = 1000) -> str:
"""Return base-64 data-URI for a *local* image, compressed if required."""
path = Path(path_str)
img_bytes = path.read_bytes()
max_bytes = max_kb * 1024
if len(img_bytes) <= max_bytes:
mime = f"image/{path.suffix[1:]}"
return f"data:{mime};base64,{base64.b64encode(img_bytes).decode()}"
img = Image.open(io.BytesIO(img_bytes))
quality = 85
while quality > 20:
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=quality, optimize=True)
if len(buf.getvalue()) <= max_bytes:
return f"data:image/jpeg;base64,{base64.b64encode(buf.getvalue()).decode()}"
quality -= 10
img.thumbnail((img.width // 2, img.height // 2))
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=30)
return f"data:image/jpeg;base64,{base64.b64encode(buf.getvalue()).decode()}"
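# Example (illustrative): compress_image("photo.png", max_kb=500) returns a
# "data:image/png;base64,..." URI unchanged when the file is already under 500 KB;
# larger files are re-encoded as JPEG at decreasing quality (85 down to 25 in steps
# of 10) and, as a last resort, downscaled to half size and saved at quality 30.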
def fetch_remote_image(url: str) -> str | None:
"""Download remote image and return base-64 data-URI or None on failure."""
try:
r = httpx.get(url, timeout=10)
r.raise_for_status()
mime = r.headers.get("content-type", "image/jpeg").split(";")[0]
return f"data:{mime};base64,{base64.b64encode(r.content).decode()}"
except Exception as e:
logger.warning(f"Skipping remote image {url}: {e}")
return None
# -----------------------------------------------------------------------------
# LITELLM UTILITIES
# -----------------------------------------------------------------------------
def _build_params(model: str,
messages: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Return the final dict for acompletion, injecting only needed keys."""
params = {"model": model, "messages": messages}
# LiteLLM auto-detects most providers, but Ollama needs an explicit base URL
if model.startswith("ollama/"):
params["api_base"] = OLLAMA_BASE_URL
if OLLAMA_API_KEY:
params["api_key"] = OLLAMA_API_KEY
return params
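# Example (illustrative): for an Ollama model,
#   _build_params("ollama/gemma3:12b", msgs)
#   -> {"model": "ollama/gemma3:12b", "messages": msgs, "api_base": OLLAMA_BASE_URL}
#      (plus "api_key" when OLLAMA_API_KEY is set)
# Other providers get only {"model", "messages"} and rely on LiteLLM's own provider
# detection and environment variables (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.).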
# -----------------------------------------------------------------------------
# LLM call with retry
# -----------------------------------------------------------------------------
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=4))
async def _call(params: Dict[str, Any], idx: int) -> Tuple[int, str]:
resp = await acompletion(**params)
return idx, resp.choices[0].message.content
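# Retry policy: up to 3 attempts per prompt, with exponential backoff capped between
# 1 and 4 seconds between attempts; the failure propagates after the third attempt.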
# -----------------------------------------------------------------------------
# Batch runner
# -----------------------------------------------------------------------------
async def litellm_call(prompts: List[Any]) -> List[str]:
"""Run prompts: strings, dicts, or raw LiteLLM dicts."""
# Accept a single prompt as well
if isinstance(prompts, (str, dict)):
prompts = [prompts]
tasks: List[asyncio.Task] = []
for idx, item in enumerate(prompts):
# 1) Raw LiteLLM dict ----------------------------------------------------
if isinstance(item, dict) and "messages" in item:
item.setdefault("model", MODEL) # fill only if missing
tasks.append(asyncio.create_task(_call(item, idx)))
continue
# 2) Shorthand dict or plain string -------------------------------------
if isinstance(item, dict):
text = str(item.get("text", ""))
images = [str(item["image"])] if "image" in item else []
model = item.get("model", MODEL)
else:
images, text = extract_images(str(item))
model = MODEL
# Build OpenAI-style content
content_parts: List[Dict[str, Any]] = [{"type": "text", "text": text}]
for img in images:
url = fetch_remote_image(img) if img.startswith("http") else compress_image(img)
if url:
content_parts.append({"type": "image_url", "image_url": {"url": url}})
params = _build_params(model, [{"role": "user", "content": content_parts}])
tasks.append(asyncio.create_task(_call(params, idx)))
# Collect results in original order
results = [None] * len(tasks)
for coro in tqdm(
asyncio.as_completed(tasks),
total=len(tasks),
desc="Processing",
disable=not SHOW_PROGRESS
):
idx, answer = await coro
results[idx] = answer
# Redact secrets before logging
safe_prompt = deepcopy(prompts[idx])
if isinstance(safe_prompt, dict) and "api_key" in safe_prompt:
safe_prompt["api_key"] = "***"
logger.info(f"\nQ{idx}: {str(safe_prompt)[:50]}...\nA{idx}: {answer[:100]}...")
return results
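# A minimal programmatic-usage sketch (assumptions: this module is importable, a
# model/endpoint is configured via the env vars above, and "chart.png" is a
# placeholder local file). See demo()/demo_sync() below for a fuller example.
def _example_batch() -> List[str]:
    return asyncio.run(litellm_call([
        "What is 2+2?",
        {"text": "Describe this chart", "image": "chart.png"},  # hypothetical file
    ]))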
# -----------------------------------------------------------------------------
# Typer CLI
# -----------------------------------------------------------------------------
# @cli.callback()
# def main():
# pass
@cli.command()
def main(
sources: List[str] = typer.Argument(None, help="Prompts or files containing prompts"),
model: str = typer.Option(MODEL, "--model", "-m", help="LiteLLM model name (e.g., 'ollama/gemma3:12b', 'moonshot/kimi-k2-turbo-preview', 'gpt-4')"),
api_base: str = typer.Option(None, "--api-base", help="Override API base URL (auto-detected from model)"),
api_key: str = typer.Option(None, "--api-key", help="Override API key (auto-detected from environment)"),
stdin: bool = typer.Option(False, "--stdin", help="Read prompts from stdin"),
jsonl: bool = typer.Option(False, "--jsonl", help="Input is in JSON Lines format"),
):
"""
Run any combination of prompts via LiteLLM.
Examples\n
--------\n
# Basic usage with default model
litellm_call "What is 2+2?"\n
# Use different models
litellm_call --model "moonshot/kimi-k2-turbo-preview" "Explain quantum physics"\n
litellm_call --model "gpt-4" "Write a poem"\n
litellm_call --model "ollama/llama2" "Hello world"\n
# From files
litellm_call @questions.txt\n
litellm_call prompts.json\n
cat lines.jsonl | litellm_call --stdin --jsonl\n
"""
# Update globals if CLI overrides provided
global MODEL
if model:
MODEL = model
# Set provider-specific overrides if provided
if api_base:
os.environ["OLLAMA_API_BASE"] = api_base
os.environ["MOONSHOT_API_BASE"] = api_base
os.environ["OPENAI_API_BASE"] = api_base
if api_key:
# Determine which provider based on model prefix
if model.startswith("ollama/"):
os.environ["OLLAMA_API_KEY"] = api_key
elif model.startswith("moonshot/"):
os.environ["MOONSHOT_API_KEY"] = api_key
elif model.startswith("gpt") or model.startswith("text-"):
os.environ["OPENAI_API_KEY"] = api_key
elif model.startswith("claude"):
os.environ["ANTHROPIC_API_KEY"] = api_key
prompts: List[Any] = []
# 1) STDIN
if stdin or (sources == ["-"]):
for line in sys.stdin:
line = line.rstrip("\n")
if jsonl:
prompts.append(json.loads(line))
else:
prompts.append(line)
# 2) Positional sources
for src in sources or []:
if src == "-":
continue # already handled above
# @file expansion
if src.startswith("@"):
src = src[1:]
path = Path(src)
if not path.exists():
# Treat literal string
prompts.append(src)
continue
# Decide how to parse the file
if path.suffix.lower() == ".json":
prompts.extend(json.loads(path.read_text()))
elif path.suffix.lower() == ".jsonl" or jsonl:
prompts.extend(json.loads(l) for l in path.read_text().splitlines() if l.strip())
else:
prompts.extend(path.read_text().splitlines())
if not prompts:
typer.echo("No prompts provided.", err=True)
raise typer.Exit(1)
results = asyncio.run(litellm_call(prompts))
for r in results:
typer.echo(r)
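# Accepted prompt-file formats (contents are illustrative):
#   prompts.txt    one plain-text prompt per line
#   prompts.json   a JSON array, e.g. ["What is 2+2?", {"text": "Explain", "image": "a.png"}]
#   prompts.jsonl  one JSON value per line (use --jsonl when reading from stdin)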
# ---------------------------------------------------------------------------
# Quick test
# ---------------------------------------------------------------------------
async def demo() -> List[str]:
"""
Run a canned set of prompts and return the results.
Safe to call from other async code.
"""
prompts = [
"What is the capital of France?",
"Calculate 15+27+38",
"What is 3 + 5? Return JSON: {question:string,answer:number}",
"What is this animal eating? proof_of_concept/ollama_turbo/images/image2.png",
"Describe https://upload.wikimedia.org/wikipedia/commons/thumb/9/90/Labrador_Retriever_portrait.jpg/960px-Labrador_Retriever_portrait.jpg and https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/960px-Cat_November_2010-1a.jpg",
{"text": "Explain this meme", "image": "proof_of_concept/ollama_turbo/images/image.png"},
{
"model": "ollama/gpt-oss:120b",
"api_base": "https://ollama.com",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me a short joke."}
],
"temperature": 1.0
}
]
return await litellm_call(prompts)
def demo_sync() -> List[str]:
"""
Synchronous wrapper around `demo()` for callers that are not async.
"""
return asyncio.run(demo())
if __name__ == "__main__":
    cli()
"""
Initializes LiteLLM Caching Configuration.
Module: litellm_cache.py
Description: Functions for litellm cache operations
This module sets up LiteLLM's caching mechanism. It attempts to configure
Redis as the primary cache backend. If Redis is unavailable or fails connection
tests, it falls back to using LiteLLM's built-in in-memory cache. Includes
a test function to verify cache functionality.
Relevant Documentation:
- LiteLLM Caching: https://docs.litellm.ai/docs/proxy/caching
- Redis Python Client: https://redis.io/docs/clients/python/
- Project Caching Notes: ../../repo_docs/caching_strategy.md (Placeholder)
Input/Output:
- Input: Environment variables for Redis connection (optional).
- Output: Configures LiteLLM global cache settings. Logs status messages.
- The `test_litellm_cache` function demonstrates usage by making cached calls.
"""
# ==============================================================================
# !!! WARNING - DO NOT MODIFY THIS FILE !!!
# ==============================================================================
# This is a core functionality file that other parts of the system depend on.
# Changing this file (especially its sync/async nature) will break multiple dependent systems.
# Any changes here will cascade into test failures across the codebase.
#
# If you think you need to modify this file:
# 1. DON'T. The synchronous implementation is intentional
# 2. The caching system is working as designed
# 3. Test files should adapt to this implementation, not vice versa
# 4. Consult LESSONS_LEARNED.md about not breaking working code
# ==============================================================================
import os
import sys  # used for exit codes in the __main__ block
from typing import Any  # Import Any for type hint
import litellm
import redis
from dotenv import load_dotenv # Import dotenv for environment variable loading
from litellm.caching.caching import (
    Cache as LiteLLMCache,
    LiteLLMCacheType,
)
from loguru import logger
from lean4_prover.utils.log_utils import truncate_large_value
# load_env_file() # Removed - Docker Compose handles .env loading via env_file
load_dotenv()
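# Expected environment (all optional; the values shown are the defaults used below):
#   REDIS_HOST=localhost
#   REDIS_PORT=6379
#   REDIS_PASSWORD=<unset>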
def initialize_litellm_cache() -> None:
redis_host = os.getenv("REDIS_HOST", "localhost")
redis_port = int(os.getenv("REDIS_PORT", 6379))
redis_password = os.getenv(
"REDIS_PASSWORD", None
) # Assuming password might be needed
try:
logger.debug(
f"Starting LiteLLM cache initialization (Redis target: {redis_host}:{redis_port})..."
)
# Test Redis connection before enabling caching
logger.debug("Testing Redis connection...")
test_redis = redis.Redis(
host=redis_host,
port=redis_port,
password=redis_password,
socket_timeout=2,
decode_responses=True, # Added decode_responses for easier debugging if needed
)
if not test_redis.ping():
raise ConnectionError(
f"Redis is not responding at {redis_host}:{redis_port}."
)
# Verify Redis is empty or log existing keys
keys = test_redis.keys("*")
if keys:
# Use the truncate utility for logging keys
logger.debug(f"Existing Redis keys: {truncate_large_value(keys)}")
else:
logger.debug("Redis cache is empty")
# Set up LiteLLM cache with debug logging
logger.debug("Configuring LiteLLM Redis cache...")
litellm.cache = LiteLLMCache( # Use imported Cache
type=LiteLLMCacheType.REDIS, # Use Enum/Type
host=redis_host,
port=str(redis_port), # Ensure port is a string
password=redis_password,
supported_call_types=["acompletion", "completion"],
ttl=60 * 60 * 24 * 2, # 2 days
)
# Enable caching and verify
logger.debug("Enabling LiteLLM cache...")
litellm.enable_cache()
# Set debug logging for LiteLLM
os.environ["LITELLM_LOG"] = "DEBUG"
# Verify cache configuration
logger.debug(
f"LiteLLM cache config: {litellm.cache.__dict__ if hasattr(litellm.cache, '__dict__') else 'No cache config available'}"
)
logger.info(" Redis caching enabled on localhost:6379")
# Try a test set/get to verify Redis is working
try:
test_key = "litellm_cache_test"
test_redis.set(test_key, "test_value", ex=60)
result = test_redis.get(test_key)
logger.debug(f"Redis test write/read successful: {result == 'test_value'}")
test_redis.delete(test_key)
except Exception as e:
logger.warning(f"Redis test write/read failed: {e}")
except (redis.ConnectionError, redis.TimeoutError) as e:
logger.warning(
f"⚠️ Redis connection failed: {e}. Falling back to in-memory caching."
)
# Fall back to in-memory caching if Redis is unavailable
logger.debug("Configuring in-memory cache fallback...")
litellm.cache = LiteLLMCache(type=LiteLLMCacheType.LOCAL) # Use Enum/Type
litellm.enable_cache()
logger.debug("In-memory cache enabled")
def test_litellm_cache() -> tuple[bool, dict[str, bool | None]]:
"""
Test the LiteLLM cache functionality with a sample completion call.
Returns a tuple: (overall_success, cache_hit_details)
"""
initialize_litellm_cache()
test_success = False
# Explicitly annotate the dictionary type
cache_details: dict[str, bool | None] = {"cache_hit1": None, "cache_hit2": None}
try:
# Test the cache with a simple completion call
test_messages = [
{
"role": "user",
"content": "What is the capital of France? Respond concisely.",
}
]
logger.info("Testing cache with completion call...")
# First call - Expect cache miss (cache_hit should be False or None)
response1 = litellm.completion(
model=os.getenv("LITELLM_TEST_MODEL", os.getenv("LEAN4_MODEL")), # Use test model or main model
messages=test_messages,
# Ensure caching is attempted, remove specific cache param unless needed for override
)
usage1 = getattr(response1, "usage", "N/A")
hidden_params1 = getattr(response1, "_hidden_params", {})
cache_hit1 = hidden_params1.get(
"cache_hit"
) # Could be None if not hit or feature disabled
cache_details["cache_hit1"] = cache_hit1
logger.info(f"First call usage: {usage1}")
logger.info(f"Response 1: Cache hit: {cache_hit1}")
# Check if first call was NOT a cache hit (allow None or False)
first_call_missed = cache_hit1 is None or cache_hit1 is False
# Second call - Expect cache hit (cache_hit should be True)
response2 = litellm.completion(
model=os.getenv("LITELLM_TEST_MODEL", os.getenv("LEAN4_MODEL")), # Use test model or main model
messages=test_messages,
# Ensure caching is attempted
)
usage2 = getattr(response2, "usage", "N/A")
hidden_params2 = getattr(response2, "_hidden_params", {})
cache_hit2 = hidden_params2.get("cache_hit")
cache_details["cache_hit2"] = cache_hit2
logger.info(f"Second call usage: {usage2}")
logger.info(f"Response 2: Cache hit: {cache_hit2}")
# Check if second call WAS a cache hit
second_call_hit = cache_hit2 is True
# Determine overall test success
test_success = first_call_missed and second_call_hit
except Exception as e:
logger.error(f"Cache test failed during execution: {e}")
test_success = False # Ensure failure is recorded
return test_success, cache_details
if __name__ == "__main__":
logger.info("--- Running Standalone Validation for initialize_litellm_cache.py ---")
tests_passed_count = 0
tests_failed_count = 0
total_tests = 1
try:
test_result, details = test_litellm_cache()
if test_result:
tests_passed_count += 1
logger.success(" Test 'cache_hit_miss': PASSED")
else:
tests_failed_count += 1
logger.error(" Test 'cache_hit_miss': FAILED")
logger.error(
" Expected first call cache_hit=False/None, second call cache_hit=True."
)
logger.error(
f" Got: cache_hit1={details.get('cache_hit1')}, cache_hit2={details.get('cache_hit2')}"
)
except Exception as e:
tests_failed_count += 1 # Count exception as failure
logger.error(
" Test 'cache_hit_miss': FAILED due to exception during test execution."
)
logger.error(f" Exception: {e}", exc_info=True)
# --- Report validation status ---
print("\n--- Test Summary ---")
print(f"Total Tests: {total_tests}")
print(f"Passed: {tests_passed_count}")
print(f"Failed: {tests_failed_count}")
if tests_failed_count == 0:
print("\n VALIDATION COMPLETE - All LiteLLM cache tests passed.")
sys.exit(0)
else:
print("\n VALIDATION FAILED - LiteLLM cache test failed.")
# Error details already logged above
sys.exit(1)