import google.genai as genai
# Correct import for genai types
from google.genai import types as genai_types
from google.genai.errors import APIError
from google.api_core import exceptions as api_core_exceptions
from httpx import ReadTimeout
import time
import os
import json
import logging
from pathlib import Path
from typing import List, Dict, Optional, Union, Any
# --- Configuration ---
# Basic logging setup
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# --- Constants ---
# Status values for JSONL output
STATUS_PENDING = "pending"
STATUS_SUCCESSFUL = "successful"
STATUS_FAILED = "failed"

# Errors considered potentially temporary and worth retrying:
# 500 Internal Server Error, 503 Service Unavailable, 429 Resource Exhausted (rate limiting).
# DeadlineExceeded is often the specific exception for timeouts.
RETRYABLE_API_ERROR_CODES = (500, 503, 429)

# Using specific exception classes for timeouts is more robust.
RETRYABLE_EXCEPTIONS = (
    api_core_exceptions.DeadlineExceeded,     # Explicit timeout from api_core
    api_core_exceptions.ServiceUnavailable,   # 503 from api_core
    api_core_exceptions.ResourceExhausted,    # 429 from api_core
    api_core_exceptions.InternalServerError,  # 500 from api_core
    ReadTimeout,                              # httpx.ReadTimeout for network-level timeouts
    # Consider adding httpx.ConnectTimeout if you see connection errors
)

# --- Type Definitions ---
PromptState = Dict[str, Optional[str]]  # {"prompt": str, "status": str, "result": Optional[str]}
StateDict = Dict[str, PromptState]      # Maps prompt string to its state dictionary
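# Illustrative example of a single line in the JSONL state file written by _save_state_to_jsonl
# (the prompt/result values are made up; the keys and status strings match the constants above):
#   {"prompt": "What is 2 + 2?", "status": "successful", "result": "4"}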
# --- Helper Functions ---
def _save_state_to_jsonl(filepath: Path, state: StateDict):
    """Saves the current state dictionary to a JSONL file, overwriting the existing file."""
    try:
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            for prompt, data in state.items():
                # Ensure the structure matches the defined PromptState
                line_data = {
                    "prompt": prompt,  # Prompt is the key in StateDict; ensure it's in the line
                    "status": data.get("status", STATUS_PENDING),  # Default to pending if missing
                    "result": data.get("result", None)  # Default to None if missing
                }
                json.dump(line_data, f, ensure_ascii=False)
                f.write('\n')
        logging.debug(f"State ({len(state)} entries) saved to {filepath}")
    except IOError as e:
        logging.error(f"Failed to save state to {filepath}: {e}")
    except TypeError as e:
        # Log which entry might be problematic
        logging.error(f"Failed to serialize state to JSONL: {e} - Problematic data: {state.get(prompt, {})}")
def _load_state_from_jsonl(filepath: Path) -> StateDict:
    """
    Loads processing state from a JSONL file.
    Returns a dictionary mapping prompts to their state information.
    Handles file not found, empty file, and invalid JSON lines gracefully.
    """
    state: StateDict = {}
    if not filepath.exists() or filepath.stat().st_size == 0:
        logging.info(f"No existing state file found at {filepath} or file is empty. Starting fresh.")
        return state
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            line_num = 0
            for line in f:
                line_num += 1
                line = line.strip()
                if not line:
                    continue  # Skip empty lines
                try:
                    data = json.loads(line)
                    if isinstance(data, dict) and "prompt" in data and "status" in data:
                        # Basic validation
                        prompt = data["prompt"]
                        status = data["status"]
                        result = data.get("result")  # Result might be null/None
                        if not isinstance(prompt, str) or not isinstance(status, str):
                            logging.warning(f"Skipping line {line_num} in {filepath}: Invalid 'prompt' or 'status' type.")
                            continue
                        # Allow only defined statuses; otherwise treat as pending
                        if status not in [STATUS_PENDING, STATUS_SUCCESSFUL, STATUS_FAILED]:
                            logging.warning(f"Line {line_num} in {filepath} has invalid status '{status}'. Treating as pending.")
                            status = STATUS_PENDING  # Or could choose to skip completely
                            result = None  # Reset result if status is invalid
                        # Store the state using the prompt as the key
                        state[prompt] = {"prompt": prompt, "status": status, "result": result}
                    else:
                        logging.warning(f"Skipping line {line_num} in {filepath}: Missing 'prompt' or 'status' key, or not a JSON object.")
                except json.JSONDecodeError:
                    logging.warning(f"Skipping line {line_num} in {filepath}: Invalid JSON.")
                except Exception as inner_e:  # Catch other potential errors during line processing
                    logging.error(f"Unexpected error processing line {line_num} in {filepath}: {inner_e}")
        logging.info(f"Loaded {len(state)} existing prompt states from {filepath}")
        return state
    except IOError as e:
        logging.error(f"Could not read state file {filepath}: {e}. Starting fresh.")
        return {}  # Return empty state on read error
    except Exception as e:
        logging.error(f"Unexpected error loading state from {filepath}: {e}. Starting fresh.")
        return {}
def _initialize_genai_client(
    api_key: Optional[str] = None,
    vertex_ai_config: Optional[Dict[str, str]] = None,
    http_options: Optional[genai_types.HttpOptions] = None
) -> genai.Client:
    """Initializes and returns the google.genai Client."""
    use_vertex = False
    project_id = None
    location = None

    # Prioritize the vertex_ai_config dictionary
    if vertex_ai_config:
        project_id = vertex_ai_config.get('project')
        location = vertex_ai_config.get('location')
        if project_id and location:
            use_vertex = True
            logging.info(f"Using Vertex AI configuration: project='{project_id}', location='{location}'")
        else:
            raise ValueError("Vertex AI config requires both 'project' and 'location' keys.")
    # Check environment variables if not configured via dict
    elif os.getenv('GOOGLE_GENAI_USE_VERTEXAI', 'false').lower() == 'true':
        project_id = os.getenv('GOOGLE_CLOUD_PROJECT')
        location = os.getenv('GOOGLE_CLOUD_LOCATION')
        if project_id and location:
            use_vertex = True
            logging.info(f"Using Vertex AI via environment variables: project='{project_id}', location='{location}'")
        else:
            raise ValueError("Vertex AI requires GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION env vars when GOOGLE_GENAI_USE_VERTEXAI is true.")

    # Fallback to the Gemini Developer API key
    if not use_vertex:
        api_key = api_key or os.getenv('GOOGLE_API_KEY')
        if not api_key:
            raise ValueError("API key must be provided either via argument or the GOOGLE_API_KEY environment variable for the Gemini Developer API.")
        logging.info("Using Gemini Developer API Key.")
        return genai.Client(api_key=api_key, http_options=http_options)
    else:
        # Vertex AI initialization
        return genai.Client(
            vertexai=True,
            project=project_id,
            location=location,
            http_options=http_options
        )
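# Illustrative usage of the helper above; the key, project, and location values are placeholders:
#   client = _initialize_genai_client(api_key="YOUR_API_KEY")
#   client = _initialize_genai_client(vertex_ai_config={"project": "my-gcp-project", "location": "us-central1"})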
# --- Main Entrypoint Function ---
def query_llm_with_retry(
    prompts: List[str],
    output_filepath: Union[str, Path],
    model_name: str = "gemini-2.5-flash-001",  # Default to a generally available model
    api_key: Optional[str] = None,
    vertex_ai_config: Optional[Dict[str, str]] = None,  # e.g., {'project': 'your-gcp-project', 'location': 'us-central1'}
    max_retries: int = 3,
    retry_delay_seconds: int = 5,
    api_timeout_seconds: int = 120,
    generation_config_dict: Optional[Dict[str, Any]] = None,  # e.g., {"temperature": 0.7, "max_output_tokens": 1024}
    safety_settings: Optional[List[Dict[str, Any]]] = None,  # e.g., [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'threshold': 'BLOCK_NONE'}]
    system_instruction: Optional[str] = None,
    thinking_config: Optional[Dict[str, Any]] = None  # e.g., {"include_thoughts": True}
) -> StateDict:
    """
    Queries a Google Generative AI model for a list of prompts with retries and state persistence.

    Loads existing results from `output_filepath`, processes only prompts not yet successfully
    completed, and saves results incrementally.

    Args:
        prompts: A list of unique input strings (prompts) to send to the LLM.
            **Important**: Prompts should be unique identifiers for tasks.
        output_filepath: Path to the JSONL file where results are stored and loaded from.
        model_name: The name of the Google Generative AI model to use.
        api_key: Your Google AI API key (required if not using Vertex AI).
            Can also be set via the GOOGLE_API_KEY environment variable.
        vertex_ai_config: Dictionary with 'project' and 'location' for Vertex AI.
            Alternatively, set GOOGLE_GENAI_USE_VERTEXAI=true,
            GOOGLE_CLOUD_PROJECT, and GOOGLE_CLOUD_LOCATION env vars.
        max_retries: Maximum number of retries for timeouts or specific API errors.
        retry_delay_seconds: Delay in seconds between retries.
        api_timeout_seconds: Timeout for each API call in seconds.
        generation_config_dict: Optional dictionary of generation parameters
            (e.g., temperature, max_output_tokens).
        safety_settings: Optional list of safety setting dictionaries.
        system_instruction: Optional system instruction string.
        thinking_config: Optional dictionary of thinking configuration parameters.

    Returns:
        A dictionary mapping each input prompt (from the original list) to its
        final state dictionary (`{"prompt": ..., "status": ..., "result": ...}`).
        This dictionary reflects the state saved in `output_filepath`.
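    Example:
        Illustrative call (the prompts, output path, and key below are placeholders):

        results = query_llm_with_retry(
            prompts=["What is 2 + 2?", "Summarize the plot of Hamlet."],
            output_filepath="llm_results.jsonl",
            api_key=os.environ.get("GOOGLE_API_KEY"),
        )
        for prompt, state in results.items():
            print(state["status"], state["result"])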
""" | |
    output_path = Path(output_filepath)

    # --- State Loading ---
    # Load existing state. Keys are prompts, values are state dictionaries.
    current_state = _load_state_from_jsonl(output_path)

    # --- Initialize potentially missing prompts in the state ---
    # This ensures that even if the script fails before processing a prompt, its key
    # exists in the final returned dictionary (with a pending status and None result).
    # It also tracks which prompts from the *current* list need processing.
    prompts_to_process = []
    initial_save_needed = False  # Flag if we add new pending prompts
    for prompt in prompts:
        if prompt not in current_state:
            # New prompt, add to state as pending
            current_state[prompt] = {"prompt": prompt, "status": STATUS_PENDING, "result": None}
            prompts_to_process.append(prompt)
            initial_save_needed = True
            logging.debug(f"Prompt '{prompt[:50]}...' added as new and pending.")
        elif current_state[prompt]["status"] in [STATUS_PENDING, STATUS_FAILED]:
            # Existing prompt that is pending or failed, so it needs processing.
            # The processing loop resets the status to PENDING for the retry attempt;
            # the old error stays in 'result' until then.
            if current_state[prompt]["status"] == STATUS_FAILED:
                logging.info(f"Retrying previously failed prompt: '{prompt[:50]}...'")
            prompts_to_process.append(prompt)
        # else: prompt already has status STATUS_SUCCESSFUL, skip it.

    if initial_save_needed:
        # Save the state immediately if new prompts were added as pending
        _save_state_to_jsonl(output_path, current_state)

    if not prompts_to_process:
        logging.info("All provided prompts are already marked as 'successful' in the state file. Nothing new to process.")
        return current_state  # Return the loaded/updated state

    logging.info(f"Found {len(prompts_to_process)} prompts to process (pending or previously failed) out of {len(prompts)} total provided.")
    # --- Client and Config Initialization ---
    # Configure HTTP options with the timeout (HttpOptions expects milliseconds)
    api_timeout_milliseconds = api_timeout_seconds * 1000
    http_opts = genai_types.HttpOptions(timeout=api_timeout_milliseconds)

    # Initialize the GenAI client
    try:
        client = _initialize_genai_client(api_key, vertex_ai_config, http_opts)
    except ValueError as e:
        logging.error(f"Client Initialization Error: {e}. Cannot proceed.")
        # Mark all prompts intended for processing in this run as failed if client init fails
        for prompt in prompts_to_process:
            if current_state[prompt]["status"] == STATUS_PENDING:  # Only update if still pending
                current_state[prompt]["status"] = STATUS_FAILED
                current_state[prompt]["result"] = "Error: Client initialization failed"
        _save_state_to_jsonl(output_path, current_state)  # Save the updated failure status
        return current_state  # Return the state reflecting the init failure

    # Construct GenerateContentConfig if parameters are provided
    generate_content_config_args = {}
    if generation_config_dict:
        filtered_gen_config = {k: v for k, v in generation_config_dict.items() if v is not None}
        if filtered_gen_config:
            generate_content_config_args.update(filtered_gen_config)
    if safety_settings:
        generate_content_config_args['safety_settings'] = safety_settings
    if system_instruction:
        generate_content_config_args['system_instruction'] = system_instruction
    if thinking_config:
        generate_content_config_args['thinking_config'] = thinking_config
    llm_config = genai_types.GenerateContentConfig(**generate_content_config_args) if generate_content_config_args else None
    # --- Process Prompts ---
    total_to_process = len(prompts_to_process)
    newly_processed_count = 0
    newly_failed_count = 0

    for i, prompt in enumerate(prompts_to_process):
        # Find the original index for logging consistency (optional)
        try:
            original_index = prompts.index(prompt) + 1
            log_prefix = f"Prompt ({i+1}/{total_to_process}, Original #{original_index}/{len(prompts)})"
        except ValueError:  # Should not happen since prompts_to_process comes from prompts
            log_prefix = f"Prompt ({i+1}/{total_to_process})"

        logging.info(f"{log_prefix}: Querying model for: '{prompt[:80]}...'")
        current_retries = 0
        # Status updates modify current_state[prompt] directly
        prompt_data = current_state[prompt]  # Reference to the state dict for this prompt
        prompt_data["status"] = STATUS_PENDING  # Ensure it's marked as pending for this attempt

        while current_retries <= max_retries:
            request_successful = False
            permanent_error = False
            error_message = None
            try:
                response = client.models.generate_content(
                    model=model_name,
                    contents=prompt,
                    config=llm_config
                )

                # --- Process Response ---
                # Check for blocked content first
                if response.prompt_feedback and response.prompt_feedback.block_reason:
                    reason = response.prompt_feedback.block_reason.name
                    message = response.prompt_feedback.block_reason_message or "No specific message."
                    error_message = f"Error: Blocked (prompt) - {reason}. Message: {message}"
                    logging.warning(f"{log_prefix}: Failed - {error_message}")
                    permanent_error = True
                    break  # Exit retry loop for prompt blocking

                # Check if the response has text
                try:
                    response_text = response.text  # Handles candidate checking internally
                    # Try extracting reasoning: thought parts are flagged with `thought=True`
                    # and carry the reasoning text in their `text` field.
                    if response.candidates and response.candidates[0].content and response.candidates[0].content.parts:
                        thought_texts = [
                            part.text for part in response.candidates[0].content.parts
                            if getattr(part, "thought", False) and part.text
                        ]
                        if thought_texts:
                            reasoning = "\n".join(thought_texts)
                            logging.info(f"{log_prefix}: Reasoning: {reasoning}")
                            response_text = f"<think>{reasoning}</think>\n\n{response_text}"
                    logging.info(f"{log_prefix}: Successfully received response.")
                    prompt_data["status"] = STATUS_SUCCESSFUL
                    prompt_data["result"] = response_text
                    request_successful = True
                    break  # Exit retry loop on success
                except ValueError as e:
                    # Handle cases where the .text property fails (e.g., no valid candidate, safety block on the response)
                    finish_reason_str = "Unknown reason"
                    finish_message = ""
                    block_reason_str = "Unknown reason"
                    response_blocked = False
                    if response.candidates:
                        cand = response.candidates[0]
                        if cand.finish_reason:
                            finish_reason_str = cand.finish_reason.name
                        if cand.finish_message:
                            finish_message = f" Message: {cand.finish_message}"
                        if cand.safety_ratings:
                            for rating in cand.safety_ratings:
                                if rating.blocked:
                                    block_reason_str = rating.category.name
                                    error_message = f"Error: Blocked (response) due to {block_reason_str}."
                                    logging.warning(f"{log_prefix}: Failed - {error_message}")
                                    response_blocked = True
                                    permanent_error = True
                                    break  # Break inner safety check loop
                    if response_blocked:
                        break  # Break retry loop if the response was blocked
                    # If not blocked, log the general extraction failure
                    error_message = f"Error: Could not extract text. Finish Reason: {finish_reason_str}{finish_message}. Details: {e}"
                    logging.warning(f"{log_prefix}: Failed - {error_message}")
                    permanent_error = True  # Consider this permanent
                    break  # Exit retry loop
            # --- Error Handling & Retries ---
            except RETRYABLE_EXCEPTIONS as e:
                current_retries += 1
                error_message = f"Retryable error (Attempt {current_retries}/{max_retries}): {type(e).__name__} - {e}"
                logging.warning(f"{log_prefix}: {error_message}")
                if current_retries > max_retries:
                    error_message = f"Error: Failed after {max_retries} retries due to {type(e).__name__}."
                    logging.error(f"{log_prefix}: {error_message}")
                    permanent_error = True  # Exhausted retries
                    break
                else:
                    logging.info(f"Waiting {retry_delay_seconds}s before retrying...")
                    time.sleep(retry_delay_seconds)
                    # Loop continues automatically
            except APIError as e:
                error_code = e.code
                is_retryable = error_code in RETRYABLE_API_ERROR_CODES
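                # A minimal sketch of the remaining handling, assuming the same retry
                # pattern as the branch above and the incremental save-and-return
                # behaviour described in the docstring:
                if is_retryable and current_retries < max_retries:
                    current_retries += 1
                    logging.warning(f"{log_prefix}: Retryable API error {error_code} "
                                    f"(Attempt {current_retries}/{max_retries}): {e}")
                    logging.info(f"Waiting {retry_delay_seconds}s before retrying...")
                    time.sleep(retry_delay_seconds)
                else:
                    error_message = f"Error: APIError {error_code} - {e}"
                    logging.error(f"{log_prefix}: {error_message}")
                    permanent_error = True
                    break

        # Record the outcome and persist state after every prompt.
        if request_successful:
            newly_processed_count += 1
        else:
            prompt_data["status"] = STATUS_FAILED
            prompt_data["result"] = error_message or "Error: Unknown failure"
            newly_failed_count += 1
        _save_state_to_jsonl(output_path, current_state)

    logging.info(f"Processing complete. Newly successful: {newly_processed_count}, newly failed: {newly_failed_count}.")
    return current_state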