
@titu1994
Created April 16, 2025 05:41
import google.genai as genai
# Correct import for genai types
from google.genai import types as genai_types
from google.genai.errors import APIError
from google.api_core import exceptions as api_core_exceptions
from httpx import ReadTimeout
import time
import os
import json
import logging
from pathlib import Path
from typing import List, Dict, Optional, Union, Any, Literal
# --- Configuration ---
# Basic logging setup
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
# --- Constants ---
# Status values for JSONL output
STATUS_PENDING = "pending"
STATUS_SUCCESSFUL = "successful"
STATUS_FAILED = "failed"
# Errors considered potentially temporary and worth retrying
# 500 Internal Server Error, 503 Service Unavailable, 429 Resource Exhausted (Rate Limiting)
# DeadlineExceeded is often the specific exception for timeouts.
RETRYABLE_API_ERROR_CODES = (500, 503, 429)
# Using a specific exception class for timeouts is more robust
RETRYABLE_EXCEPTIONS = (
    api_core_exceptions.DeadlineExceeded,     # Explicit timeout from api_core
    api_core_exceptions.ServiceUnavailable,   # 503 from api_core
    api_core_exceptions.ResourceExhausted,    # 429 from api_core
    api_core_exceptions.InternalServerError,  # 500 from api_core
    ReadTimeout,  # httpx.ReadTimeout for network-level timeouts
    # Consider adding httpx.ConnectTimeout if you see connection errors
)
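# As the note above suggests, a sketch of how the tuple could be extended to
# also retry on connection-level failures (ConnectTimeout is a real httpx
# exception; whether to retry it depends on your network environment):
# from httpx import ConnectTimeout
# RETRYABLE_EXCEPTIONS = RETRYABLE_EXCEPTIONS + (ConnectTimeout,)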
# --- Type Definitions ---
PromptState = Dict[str, Optional[str]]  # {"prompt": str, "status": str, "result": Optional[str]}
StateDict = Dict[str, PromptState] # Maps prompt string to its state dictionary
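# Illustrative example of one line in the resulting JSONL state file
# (the prompt and result shown here are made up):
# {"prompt": "What is 2+2?", "status": "successful", "result": "4"}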
# --- Helper Functions ---
def _save_state_to_jsonl(filepath: Path, state: StateDict):
    """Saves the current state dictionary to a JSONL file, overwriting the existing file."""
    try:
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            for prompt, data in state.items():
                # Ensure the structure matches the defined PromptState
                line_data = {
                    "prompt": prompt,  # Prompt is the key in StateDict; ensure it's in the line
                    "status": data.get("status", STATUS_PENDING),  # Default to pending if missing
                    "result": data.get("result", None)  # Default to None if missing
                }
                json.dump(line_data, f, ensure_ascii=False)
                f.write('\n')
        logging.debug(f"State ({len(state)} entries) saved to {filepath}")
    except IOError as e:
        logging.error(f"Failed to save state to {filepath}: {e}")
    except TypeError as e:
        # `prompt` is the last key attempted in the loop above, i.e. the entry
        # most likely responsible for the serialization failure.
        logging.error(f"Failed to serialize state to JSONL: {e} - Problematic data: {state.get(prompt, {})}")
def _load_state_from_jsonl(filepath: Path) -> StateDict:
    """
    Loads processing state from a JSONL file.
    Returns a dictionary mapping prompts to their state information.
    Handles file not found, empty file, and invalid JSON lines gracefully.
    """
    state: StateDict = {}
    if not filepath.exists() or filepath.stat().st_size == 0:
        logging.info(f"No existing state file found at {filepath} or file is empty. Starting fresh.")
        return state
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            line_num = 0
            for line in f:
                line_num += 1
                line = line.strip()
                if not line:
                    continue  # Skip empty lines
                try:
                    data = json.loads(line)
                    if isinstance(data, dict) and "prompt" in data and "status" in data:
                        # Basic validation
                        prompt = data["prompt"]
                        status = data["status"]
                        result = data.get("result")  # Result might be null/None
                        if not isinstance(prompt, str) or not isinstance(status, str):
                            logging.warning(f"Skipping line {line_num} in {filepath}: Invalid 'prompt' or 'status' type.")
                            continue
                        # Allow only defined statuses or mark as pending
                        if status not in [STATUS_PENDING, STATUS_SUCCESSFUL, STATUS_FAILED]:
                            logging.warning(f"Skipping line {line_num} in {filepath}: Invalid status '{status}'. Treating as pending.")
                            status = STATUS_PENDING  # Or could choose to skip completely
                            result = None  # Reset result if status is invalid
                        # Store the state using the prompt as the key
                        state[prompt] = {"prompt": prompt, "status": status, "result": result}
                    else:
                        logging.warning(f"Skipping line {line_num} in {filepath}: Missing 'prompt' or 'status' key, or not a JSON object.")
                except json.JSONDecodeError:
                    logging.warning(f"Skipping line {line_num} in {filepath}: Invalid JSON.")
                except Exception as inner_e:  # Catch other potential errors during line processing
                    logging.error(f"Unexpected error processing line {line_num} in {filepath}: {inner_e}")
        logging.info(f"Loaded {len(state)} existing prompt states from {filepath}")
        return state
    except IOError as e:
        logging.error(f"Could not read state file {filepath}: {e}. Starting fresh.")
        return {}  # Return empty state on read error
    except Exception as e:
        logging.error(f"Unexpected error loading state from {filepath}: {e}. Starting fresh.")
        return {}
def _initialize_genai_client(
    api_key: Optional[str] = None,
    vertex_ai_config: Optional[Dict[str, str]] = None,
    http_options: Optional[genai_types.HttpOptions] = None
) -> genai.Client:
    """Initializes and returns the google.genai Client."""
    use_vertex = False
    project_id = None
    location = None
    # Prioritize vertex_ai_config dictionary
    if vertex_ai_config:
        project_id = vertex_ai_config.get('project')
        location = vertex_ai_config.get('location')
        if project_id and location:
            use_vertex = True
            logging.info(f"Using Vertex AI configuration: project='{project_id}', location='{location}'")
        else:
            raise ValueError("Vertex AI config requires both 'project' and 'location' keys.")
    # Check environment variables if not configured via dict
    elif os.getenv('GOOGLE_GENAI_USE_VERTEXAI', 'false').lower() == 'true':
        project_id = os.getenv('GOOGLE_CLOUD_PROJECT')
        location = os.getenv('GOOGLE_CLOUD_LOCATION')
        if project_id and location:
            use_vertex = True
            logging.info(f"Using Vertex AI via environment variables: project='{project_id}', location='{location}'")
        else:
            raise ValueError("Vertex AI requires GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION env vars when GOOGLE_GENAI_USE_VERTEXAI is true.")
    # Fallback to Gemini Developer API Key
    if not use_vertex:
        api_key = api_key or os.getenv('GOOGLE_API_KEY')
        if not api_key:
            raise ValueError("API key must be provided either via argument or GOOGLE_API_KEY environment variable for Gemini Developer API.")
        logging.info("Using Gemini Developer API Key.")
        return genai.Client(api_key=api_key, http_options=http_options)
    else:
        # Vertex AI Initialization
        return genai.Client(
            vertexai=True,
            project=project_id,
            location=location,
            http_options=http_options
        )
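# Illustrative sketch of the two initialization paths (the key, project, and
# location values below are placeholders):
# client = _initialize_genai_client(api_key="YOUR_API_KEY")  # Gemini Developer API
# client = _initialize_genai_client(
#     vertex_ai_config={"project": "my-gcp-project", "location": "us-central1"})  # Vertex AI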
# --- Main Entrypoint Function ---
def query_llm_with_retry(
    prompts: List[str],
    output_filepath: Union[str, Path],
    model_name: str = "gemini-2.5-flash-001",  # Default to a generally available model
    api_key: Optional[str] = None,
    vertex_ai_config: Optional[Dict[str, str]] = None,  # e.g., {'project': 'your-gcp-project', 'location': 'us-central1'}
    max_retries: int = 3,
    retry_delay_seconds: int = 5,
    api_timeout_seconds: int = 120,
    generation_config_dict: Optional[Dict[str, Any]] = None,  # e.g., {"temperature": 0.7, "max_output_tokens": 1024}
    safety_settings: Optional[List[Dict[str, Any]]] = None,  # e.g., [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'threshold': 'BLOCK_NONE'}]
    system_instruction: Optional[str] = None,
    thinking_config: Optional[Dict[str, Any]] = None  # e.g., {"include_thoughts": True}
) -> StateDict:
    """
    Queries a Google Generative AI model for a list of prompts with retries and state persistence.

    Loads existing results from `output_filepath`, processes only prompts not yet successfully
    completed, and saves results incrementally.

    Args:
        prompts: A list of unique input strings (prompts) to send to the LLM.
            **Important**: Prompts should be unique identifiers for tasks.
        output_filepath: Path to the JSONL file where results are stored and loaded from.
        model_name: The name of the Google Generative AI model to use.
        api_key: Your Google AI API key (required if not using Vertex AI).
            Can also be set via the GOOGLE_API_KEY environment variable.
        vertex_ai_config: Dictionary with 'project' and 'location' for Vertex AI.
            Alternatively, set GOOGLE_GENAI_USE_VERTEXAI=true,
            GOOGLE_CLOUD_PROJECT, and GOOGLE_CLOUD_LOCATION env vars.
        max_retries: Maximum number of retries for timeout or specific API errors.
        retry_delay_seconds: Delay in seconds between retries.
        api_timeout_seconds: Timeout for the API call in seconds.
        generation_config_dict: Optional dictionary for generation parameters
            (e.g., temperature, max_output_tokens).
        safety_settings: Optional list of safety settings dictionaries.
        system_instruction: Optional system instruction string.
        thinking_config: Optional dictionary for thinking configuration parameters.

    Returns:
        A dictionary mapping each input prompt (from the original list) to its
        final state dictionary (`{"prompt": ..., "status": ..., "result": ...}`).
        This dictionary reflects the state saved in `output_filepath`.
    """
    output_path = Path(output_filepath)
    # --- State Loading ---
    # Load existing state. Keys are prompts, values are state dictionaries.
    current_state = _load_state_from_jsonl(output_path)
    # --- Initialize potentially missing prompts in results ---
    # This ensures that even if the script fails before processing a prompt,
    # its key exists in the final returned dictionary, mapped to a pending state.
    # It also helps track which prompts from the *current* list need processing.
    prompts_to_process = []
    initial_save_needed = False  # Flag if we add new pending prompts
    for prompt in prompts:
        if prompt not in current_state:
            # New prompt, add to state as pending
            current_state[prompt] = {"prompt": prompt, "status": STATUS_PENDING, "result": None}
            prompts_to_process.append(prompt)
            initial_save_needed = True
            logging.debug(f"Prompt '{prompt[:50]}...' added as new and pending.")
        elif current_state[prompt]["status"] in [STATUS_PENDING, STATUS_FAILED]:
            # Existing prompt that is pending or failed, needs processing.
            if current_state[prompt]["status"] == STATUS_FAILED:
                logging.info(f"Retrying previously failed prompt: '{prompt[:50]}...'")
                # Keep the old error in `result` for now; the processing loop
                # below resets the status to pending when it attempts the prompt.
            prompts_to_process.append(prompt)
        # else: prompt already has status STATUS_SUCCESSFUL, skip it.
    if initial_save_needed:
        # Save the state immediately if new prompts were added as pending
        _save_state_to_jsonl(output_path, current_state)
    if not prompts_to_process:
        logging.info("All provided prompts are already marked as 'successful' in the state file. Nothing new to process.")
        return current_state  # Return the loaded/updated state
    logging.info(f"Found {len(prompts_to_process)} prompts to process (pending or previously failed) out of {len(prompts)} total provided.")
    # --- Client and Config Initialization ---
    # Configure HTTP options with timeout (HttpOptions expects milliseconds)
    api_timeout_milliseconds = api_timeout_seconds * 1000
    http_opts = genai_types.HttpOptions(timeout=api_timeout_milliseconds)
    # Initialize the GenAI client
    try:
        client = _initialize_genai_client(api_key, vertex_ai_config, http_opts)
    except ValueError as e:
        logging.error(f"Client Initialization Error: {e}. Cannot proceed.")
        # Mark all prompts intended for processing in this run as failed if client init fails
        for prompt in prompts_to_process:
            if current_state[prompt]["status"] == STATUS_PENDING:  # Only update if still pending
                current_state[prompt]["status"] = STATUS_FAILED
                current_state[prompt]["result"] = "Error: Client initialization failed"
        _save_state_to_jsonl(output_path, current_state)  # Save the updated failure status
        return current_state  # Return the state reflecting the init failure
    # Construct GenerateContentConfig if parameters are provided
    generate_content_config_args = {}
    if generation_config_dict:
        filtered_gen_config = {k: v for k, v in generation_config_dict.items() if v is not None}
        if filtered_gen_config:
            generate_content_config_args.update(filtered_gen_config)
    if safety_settings:
        generate_content_config_args['safety_settings'] = safety_settings
    if system_instruction:
        generate_content_config_args['system_instruction'] = system_instruction
    if thinking_config:
        generate_content_config_args['thinking_config'] = thinking_config
    llm_config = genai_types.GenerateContentConfig(**generate_content_config_args) if generate_content_config_args else None
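    # Illustrative example of the config produced above for
    # generation_config_dict={"temperature": 0.7}, system_instruction="Be brief":
    #   genai_types.GenerateContentConfig(temperature=0.7, system_instruction="Be brief")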
    # --- Process Prompts ---
    total_to_process = len(prompts_to_process)
    newly_processed_count = 0
    newly_failed_count = 0
    for i, prompt in enumerate(prompts_to_process):
        # Find original index for logging consistency if needed (optional)
        try:
            original_index = prompts.index(prompt) + 1
            log_prefix = f"Prompt ({i+1}/{total_to_process}, Original #{original_index}/{len(prompts)})"
        except ValueError:  # Should not happen if prompts_to_process comes from prompts
            log_prefix = f"Prompt ({i+1}/{total_to_process})"
        logging.info(f"{log_prefix}: Querying model for: '{prompt[:80]}...'")
        current_retries = 0
        # Status updates will modify current_state[prompt] directly
        prompt_data = current_state[prompt]  # Get reference to the state dict for this prompt
        prompt_data["status"] = STATUS_PENDING  # Ensure it's marked as pending for this attempt
        while current_retries <= max_retries:
            request_successful = False
            permanent_error = False
            error_message = None
            try:
                response = client.models.generate_content(
                    model=model_name,
                    contents=prompt,
                    config=llm_config
                )
                # --- Process Response ---
                # Check for blocked content first
                if response.prompt_feedback and response.prompt_feedback.block_reason:
                    reason = response.prompt_feedback.block_reason.name
                    message = response.prompt_feedback.block_reason_message or "No specific message."
                    error_message = f"Error: Blocked (prompt) - {reason}. Message: {message}"
                    logging.warning(f"{log_prefix}: Failed - {error_message}")
                    permanent_error = True
                    break  # Exit retry loop for prompt blocking
                # Check if response has text
                try:
                    response_text = response.text  # Concatenated text of the first candidate; may be None
                    if response_text is None:
                        raise ValueError("Response contained no extractable text parts.")
                    # Try extracting reasoning. Thought parts are flagged with
                    # part.thought == True and carry their text in part.text.
                    if response.candidates and response.candidates[0].content and response.candidates[0].content.parts:
                        thought_texts = [
                            part.text for part in response.candidates[0].content.parts
                            if getattr(part, "thought", False) and part.text
                        ]
                        if thought_texts:
                            reasoning = "\n".join(thought_texts)
                            logging.info(f"{log_prefix}: Reasoning: {reasoning}")
                            response_text = f"<think>{reasoning}</think>\n\n{response_text}"
                    logging.info(f"{log_prefix}: Successfully received response.")
                    prompt_data["status"] = STATUS_SUCCESSFUL
                    prompt_data["result"] = response_text
                    request_successful = True
                    break  # Exit retry loop on success
                except ValueError as e:
                    # Handle cases where text extraction fails (e.g., no valid candidate, safety block on response)
                    finish_reason_str = "Unknown reason"
                    finish_message = ""
                    block_reason_str = "Unknown reason"
                    response_blocked = False
                    if response.candidates:
                        cand = response.candidates[0]
                        if cand.finish_reason:
                            finish_reason_str = cand.finish_reason.name
                        if cand.finish_message:
                            finish_message = f" Message: {cand.finish_message}"
                        if cand.safety_ratings:
                            for rating in cand.safety_ratings:
                                if rating.blocked:
                                    block_reason_str = rating.category.name
                                    error_message = f"Error: Blocked (response) due to {block_reason_str}."
                                    logging.warning(f"{log_prefix}: Failed - {error_message}")
                                    response_blocked = True
                                    permanent_error = True
                                    break  # Break inner safety check loop
                    if response_blocked:
                        break  # Break retry loop if response blocked
                    # If not blocked, log the general extraction failure
                    error_message = f"Error: Could not extract text. Finish Reason: {finish_reason_str}{finish_message}. Details: {e}"
                    logging.warning(f"{log_prefix}: Failed - {error_message}")
                    permanent_error = True  # Consider this permanent
                    break  # Exit retry loop
            # --- Error Handling & Retries ---
            except RETRYABLE_EXCEPTIONS as e:
                current_retries += 1
                error_message = f"Retryable error (Attempt {current_retries}/{max_retries}): {type(e).__name__} - {e}"
                logging.warning(f"{log_prefix}: {error_message}")
                if current_retries > max_retries:
                    error_message = f"Error: Failed after {max_retries} retries due to {type(e).__name__}."
                    logging.error(f"{log_prefix}: {error_message}")
                    permanent_error = True  # Exhausted retries
                    break
                else:
                    logging.info(f"Waiting {retry_delay_seconds}s before retrying...")
                    time.sleep(retry_delay_seconds)
                    # Loop continues automatically
            except APIError as e:
                error_code = e.code
                is_retryable = error_code in RETRYABLE_API_ERROR_CODES
                if is_retryable and current_retries < max_retries:
                    current_retries += 1
                    error_message = f"Retryable API error {error_code} (Attempt {current_retries}/{max_retries}): {e}"
                    logging.warning(f"{log_prefix}: {error_message}")
                    logging.info(f"Waiting {retry_delay_seconds}s before retrying...")
                    time.sleep(retry_delay_seconds)
                else:
                    error_message = f"Error: APIError {error_code} - {e}"
                    logging.error(f"{log_prefix}: {error_message}")
                    permanent_error = True
                    break
            except Exception as e:
                # Catch-all for unexpected errors; treated as permanent for this prompt
                error_message = f"Error: Unexpected {type(e).__name__} - {e}"
                logging.error(f"{log_prefix}: {error_message}")
                permanent_error = True
                break
        # --- Record Outcome & Save Incrementally ---
        if request_successful:
            newly_processed_count += 1
        else:
            prompt_data["status"] = STATUS_FAILED
            prompt_data["result"] = error_message or "Error: Unknown failure"
            newly_failed_count += 1
        _save_state_to_jsonl(output_path, current_state)
    logging.info(f"Processing complete. {newly_processed_count} succeeded, {newly_failed_count} failed out of {total_to_process} attempted.")
    return current_state
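# --- Usage Example ---
# A minimal sketch of how the entrypoint might be called. The prompts, output
# path, and generation parameters below are illustrative placeholders, and the
# run assumes a GOOGLE_API_KEY environment variable is set.
if __name__ == "__main__":
    example_prompts = [
        "Explain the difference between a list and a tuple in Python.",
        "Summarize the plot of 'The Odyssey' in two sentences.",
    ]
    final_state = query_llm_with_retry(
        prompts=example_prompts,
        output_filepath="llm_results.jsonl",
        generation_config_dict={"temperature": 0.7, "max_output_tokens": 1024},
    )
    # Re-running the script skips prompts already marked 'successful' and
    # retries any that are pending or failed.
    for prompt, data in final_state.items():
        print(f"[{data['status']}] {prompt[:60]}")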