Created June 14, 2025 16:01
Python utility script to invoke Amazon Bedrock models via the Converse API. Supports document upload, prompt caching, and extended reasoning (thinking); a usage example is included in the `__main__` block at the bottom of the file.
import re
import time
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, ClassVar, Final, Literal

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from loguru import logger

# Constants
DEFAULT_REGION: Final[str] = "us-west-2"
DEFAULT_TIMEOUT: Final[int] = 1000
MAX_RETRIES: Final[int] = 5
BASE_RETRY_DELAY: Final[int] = 5
MIN_BUDGET_TOKENS: Final[int] = 1024
SUPPORTED_DOC_FORMATS: Final[frozenset[str]] = frozenset(
    {
        "pdf",
        "csv",
        "doc",
        "docx",
        "xls",
        "xlsx",
        "html",
        "txt",
        "md",
    },
)

# Document name sanitization pattern
DOC_NAME_PATTERN: Final[re.Pattern] = re.compile(r"[\\/_,.]")

@dataclass(frozen=True)
class CacheConfig:
    """Configuration for prompt caching."""

    min_tokens: int
    max_checkpoints: int


class ConverseHelper:
    """Helper class for managing conversation messages."""

    def __init__(
        self,
        prompt: str | None = None,
        role: Literal["user", "assistant"] = "user",
    ) -> None:
        self.messages: list[dict[str, Any]] = []
        self.role = role
        if prompt:
            self.add_message(prompt)

    def add_message(
        self,
        content: str,
        role: Literal["user", "assistant"] = "user",
    ) -> None:
        """Add a message to the conversation."""
        self.role = role
        if not self.messages:
            # First call starts a new turn for the given role
            self.messages.append({"role": self.role, "content": [{"text": content}]})
        else:
            # Subsequent calls append another text block to the last turn
            self.messages[-1]["content"].append({"text": content})

    def add_cache_point(self) -> None:
        """Add a cache point to the conversation."""
        if self.messages:
            self.messages[-1]["content"].append({"cachePoint": {"type": "default"}})

    def get_messages(self) -> list[dict[str, Any]]:
        """Get the conversation messages."""
        return self.messages
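
# Illustrative only (not executed): typical setup of a user turn with a cache
# checkpoint; the prompt text here is a placeholder.
#   helper = ConverseHelper("Reusable context that exceeds the cache minimum...")
#   helper.add_cache_point()                 # cache everything added so far
#   helper.add_message("Now the question.")  # appended after the checkpoint
#   messages = helper.get_messages()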

class BedrockControlPlane:
    """Control plane for Amazon Bedrock operations."""

    def __init__(self, region_name: str = DEFAULT_REGION) -> None:
        self.client = boto3.client("bedrock", region_name=region_name)
        self.inference_profiles_type: Literal["SYSTEM_DEFINED", "APPLICATION"] = (
            "SYSTEM_DEFINED"
        )

    def list_models(self, provider: str | None = None) -> list[dict[str, Any]]:
        """List Amazon Bedrock foundation models."""
        response = self.client.list_foundation_models()
        models = [
            model
            for model in response["modelSummaries"]
            if model["modelLifecycle"]["status"] == "ACTIVE"
        ]
        if provider:
            models = [
                model
                for model in models
                if model["providerName"].strip().lower() == provider.strip().lower()
            ]
        return models

    def describe_model(self, model_id: str) -> dict[str, Any]:
        """Describe a foundation model."""
        response = self.client.get_foundation_model(modelIdentifier=model_id)
        return response["modelDetails"]

    def list_inference_profiles(
        self,
        provider: str | None = None,
    ) -> list[dict[str, Any]]:
        """List Amazon Bedrock inference profiles."""
        try:
            response = self.client.list_inference_profiles(
                typeEquals=self.inference_profiles_type,
            )
            profiles = response["inferenceProfileSummaries"]
            if provider:
                logger.info(f"Filtering inference profiles by provider: {provider}")
                provider_lower = provider.strip().lower()
                profiles = [
                    profile
                    for profile in profiles
                    if (
                        provider_lower in profile["inferenceProfileId"].strip().lower()
                        and profile["status"] == "ACTIVE"
                    )
                ]
            return profiles
        except ClientError as e:
            if e.response["Error"]["Code"] == "ResourceNotFoundException":
                logger.warning(
                    "Inference profiles not found. This feature may not be available in your region.",
                )
                return []
            raise

class BedrockAssistant:
    """Amazon Bedrock assistant for conversational AI."""

    # Prompt cache configuration for supported models
    PROMPT_CACHE_SUPPORTED_MODELS: ClassVar[dict[str, CacheConfig]] = {
        "anthropic.claude-opus-4-20250514-v1:0": CacheConfig(
            min_tokens=1024,
            max_checkpoints=4,
        ),
        "anthropic.claude-sonnet-4-20250514-v1:0": CacheConfig(
            min_tokens=1024,
            max_checkpoints=4,
        ),
        "anthropic.claude-3-7-sonnet-20250219-v1:0": CacheConfig(
            min_tokens=1024,
            max_checkpoints=4,
        ),
        "anthropic.claude-3-5-haiku-20241022-v1:0": CacheConfig(
            min_tokens=2048,
            max_checkpoints=4,
        ),
        "anthropic.claude-3-5-sonnet-20241022-v2:0": CacheConfig(
            min_tokens=1024,
            max_checkpoints=4,
        ),
        "amazon.nova-micro-v1:0": CacheConfig(min_tokens=1024, max_checkpoints=4),
        "amazon.nova-lite-v1:0": CacheConfig(min_tokens=1024, max_checkpoints=4),
        "amazon.nova-pro-v1:0": CacheConfig(min_tokens=1024, max_checkpoints=4),
    }

    def __init__(  # noqa: PLR0913
        self,
        model_id: str,
        inference_config: dict[str, Any],
        system_prompt: str = "",
        *,
        enabled_reasoning: bool = False,
        budget_tokens: int = MIN_BUDGET_TOKENS,
        region_name: str = DEFAULT_REGION,
        enable_prompt_caching: bool = True,
    ) -> None:
        """Initialize Bedrock Assistant."""
        self._setup_client(region_name)
        self.model_id = model_id
        self.enabled_reasoning = enabled_reasoning
        self.enable_prompt_caching = enable_prompt_caching
        # Setup prompt caching
        self._configure_prompt_caching()
        # Setup system prompt
        self.system_prompt = self._prepare_system_prompt(system_prompt)
        # Setup inference configuration
        self.inference_config = self._validate_inference_config(
            inference_config,
            enabled_reasoning=enabled_reasoning,
            budget_tokens=budget_tokens,
        )
        self.budget_tokens = (
            max(budget_tokens, MIN_BUDGET_TOKENS)
            if enabled_reasoning
            else budget_tokens
        )

    def _setup_client(self, region_name: str) -> None:
        """Set up the Bedrock runtime client."""
        config = Config(read_timeout=DEFAULT_TIMEOUT)
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=region_name,
            config=config,
        )

    def _configure_prompt_caching(self) -> None:
        """Configure prompt caching settings."""
        model_base_id = self._get_base_model_id()
        if model_base_id in self.PROMPT_CACHE_SUPPORTED_MODELS:
            self.prompt_caching_supported = True
            cache_config = self.PROMPT_CACHE_SUPPORTED_MODELS[model_base_id]
            self.prompt_cache_min_tokens = cache_config.min_tokens
            self.prompt_cache_max_checkpoints = cache_config.max_checkpoints
            if self.enable_prompt_caching:
                logger.info(f"Prompt caching enabled for model {self.model_id}")
                logger.info(
                    f"Minimum tokens per cache checkpoint: {self.prompt_cache_min_tokens}",
                )
        else:
            self.prompt_caching_supported = False
            if self.enable_prompt_caching:
                logger.warning(
                    f"Prompt caching requested but not supported for model {self.model_id}",
                )
                self.enable_prompt_caching = False

    def _get_base_model_id(self) -> str:
        """Extract the base model ID from a potentially region-prefixed model ID."""
        return self.model_id.removeprefix("us.")

    def _prepare_system_prompt(self, system_prompt: str) -> list[dict[str, Any]]:
        """Prepare the system prompt with an optional cache point."""
        if not system_prompt:
            return []
        prompt_content = [{"text": system_prompt}]
        if self.enable_prompt_caching and self.prompt_caching_supported:
            prompt_content.append({"cachePoint": {"type": "default"}})
        return prompt_content

    def _validate_inference_config(
        self,
        inference_config: dict[str, Any],
        *,
        enabled_reasoning: bool,
        budget_tokens: int,
    ) -> dict[str, Any]:
        """Validate and adjust the inference configuration for reasoning mode."""
        config = inference_config.copy()
        if not enabled_reasoning:
            return config
        # Reasoning mode requirements
        if config.get("temperature") != 1:
            logger.warning(
                "Temperature must be set to 1 when reasoning is enabled. Setting to 1.",
            )
            config["temperature"] = 1
        if "topP" in config:
            logger.warning("Top P must be unset when reasoning is enabled. Unsetting.")
            del config["topP"]
        if config.get("maxTokens", 0) <= budget_tokens:
            new_max_tokens = (
                config.get("maxTokens", MIN_BUDGET_TOKENS) + MIN_BUDGET_TOKENS
            )
            logger.warning(
                f"Max tokens must be greater than budget tokens when reasoning is enabled. "
                f"Setting maxTokens={new_max_tokens}.",
            )
            config["maxTokens"] = new_max_tokens
        return config
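
    # Worked example (illustrative): with reasoning enabled and budget_tokens=1024,
    # {"temperature": 0.2, "topP": 0.5, "maxTokens": 1024} is adjusted to
    # {"temperature": 1, "maxTokens": 2048} -- topP is removed and maxTokens is
    # raised above the thinking budget.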

    def set_inference_config(self, inference_config: dict[str, Any]) -> None:
        """Set the default inference config for subsequent requests."""
        self.inference_config = inference_config

    @staticmethod
    @lru_cache(maxsize=128)
    def _validate_document_format(doc_format: str) -> str:
        """Validate document format with caching."""
        if doc_format not in SUPPORTED_DOC_FORMATS:
            supported_formats = ", ".join(sorted(SUPPORTED_DOC_FORMATS))
            error_msg = (
                f"Unsupported document format: {doc_format}. "
                f"Supported formats: {supported_formats}"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
        return doc_format

    @staticmethod
    def _sanitize_document_name(doc_path: Path) -> str:
        """Sanitize the document name for Bedrock requirements."""
        # Remove spaces and replace problematic characters with hyphens
        name_without_spaces = "".join(doc_path.name.split())
        return DOC_NAME_PATTERN.sub("-", name_without_spaces)

    def prepare_document_message(self, doc_path: Path) -> dict[str, Any]:
        """Prepare a document message for the conversation."""
        if not doc_path.exists():
            error_msg = f"Document file not found: {doc_path}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)
        try:
            doc_bytes = doc_path.read_bytes()
        except Exception as e:
            error_msg = f"Failed to read document: {doc_path}"
            logger.error(f"{error_msg}: {e}")
            raise RuntimeError(error_msg) from e
        doc_format = self._validate_document_format(doc_path.suffix[1:].lower())
        doc_name = self._sanitize_document_name(doc_path)
        logger.info(f"Using document name: {doc_name}, format: {doc_format}")
        return {
            "name": doc_name,
            "format": doc_format,
            "source": {"bytes": doc_bytes},
        }
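
    # Example (illustrative): prepare_document_message(Path("sales report.pdf"))
    # yields {"name": "salesreport-pdf", "format": "pdf", "source": {"bytes": ...}},
    # which converse() attaches to the last message as a "document" content block.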

    def converse(
        self,
        messages: list[dict[str, Any]],
        doc_path: Path | None = None,
        return_type: str = "text",
        *,
        add_cache_point: bool = True,
    ) -> str | None:
        """Send a converse request to Amazon Bedrock with retry logic."""
        for attempt in range(MAX_RETRIES + 1):
            try:
                return self._make_converse_request(
                    messages,
                    doc_path,
                    return_type,
                    add_cache_point,
                )
            except ClientError as e:
                if e.response["Error"]["Code"] == "ThrottlingException":
                    if attempt >= MAX_RETRIES:
                        msg = "Max retries exceeded. Please try again later."
                        logger.error(msg)
                        raise RuntimeError(msg) from e
                    # Exponential backoff: 10s, 20s, 40s, ...
                    delay = BASE_RETRY_DELAY * (2 ** (attempt + 1))
                    logger.warning(
                        f"ThrottlingException: Retrying in {delay} seconds...",
                    )
                    time.sleep(delay)
                    continue
                error_msg = f"Bedrock API error: {e.response['Error']}"
                logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            except Exception as e:
                error_msg = "Unexpected error during Bedrock API call"
                logger.error(f"{error_msg}: {e}")
                raise RuntimeError(error_msg) from e
        logger.error("Exhausted all retry attempts")
        return None

    def _make_converse_request(
        self,
        messages: list[dict[str, Any]],
        doc_path: Path | None,
        return_type: str,
        add_cache_point: bool = True,  # accepted for API symmetry; not currently used
    ) -> str:
        """Make the actual converse request to Bedrock."""
        # Prepare additional model request fields for reasoning
        additional_fields = None
        if self.enabled_reasoning and self.budget_tokens > 0:
            additional_fields = {
                "reasoning_config": {
                    "type": "enabled",
                    "budget_tokens": self.budget_tokens,
                },
            }
        # Add document if provided
        if doc_path:
            document_dict = self.prepare_document_message(doc_path)
            messages[-1]["content"].append({"document": document_dict})
        # logger.info(f"Invoking model: {self.model_id} using Converse API...")
        response_data = self.client.converse(
            modelId=self.model_id,
            messages=messages,
            inferenceConfig=self.inference_config,
            additionalModelRequestFields=additional_fields,
            system=self.system_prompt,
        )
        self._log_cache_usage(response_data)
        return self._process_response(response_data, return_type)

    def _log_cache_usage(self, response_data: dict[str, Any]) -> None:
        """Log cache token usage if prompt caching was used."""
        if not (
            self.enable_prompt_caching
            and self.prompt_caching_supported
            and "usage" in response_data
        ):
            return
        usage = response_data["usage"]
        # logger.info(usage)
        cache_read_tokens = usage.get("cacheReadInputTokens", 0)
        cache_write_tokens = usage.get("cacheWriteInputTokens", 0)
        if cache_read_tokens > 0 or cache_write_tokens > 0:
            logger.info(
                f"Cache read tokens: {cache_read_tokens}, Cache write tokens: {cache_write_tokens}",
            )

    def _process_response(
        self,
        response: dict[str, Any],
        return_type: str = "text",
    ) -> str:
        """Process the raw API response into a user-friendly format."""
        content = response["output"]["message"]["content"]
        if not self.enabled_reasoning:
            return content[0]["text"]
        # With reasoning enabled, the first content block holds the reasoning
        # text and the second holds the final answer.
        return (
            content[0]["reasoningContent"]["reasoningText"]["text"]
            if return_type == "thinking"
            else content[1]["text"]
        )

# Test code
if __name__ == "__main__":
    from rich import print as rprint

    helper = ConverseHelper()
    fixed_prompt = Path("prompts/test_prompt.txt").read_text(
        encoding="utf-8",
    )
    helper.add_message(fixed_prompt, role="user")
    helper.add_cache_point()
    inference_config = {
        "temperature": 0.2,
        "topP": 0.5,
        "maxTokens": 10240,
    }
    assistant = BedrockAssistant(
        model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        inference_config=inference_config,
        enabled_reasoning=True,
        enable_prompt_caching=True,
    )
    response = assistant.converse(
        doc_path=Path("test_document.txt"),
        messages=helper.get_messages(),
    )
    rprint(response)
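
For reference, here is a minimal sketch of a non-reasoning call that also exercises BedrockControlPlane. It is not part of the gist: it assumes the classes above are importable (or the snippet is appended to the same module), that AWS credentials are configured, and that the inference profile ID us.amazon.nova-lite-v1:0 and the file meeting_notes.md are available in your account; adjust as needed.

from pathlib import Path

# List active Anthropic inference profiles in the region
control_plane = BedrockControlPlane(region_name="us-west-2")
for profile in control_plane.list_inference_profiles(provider="anthropic"):
    print(profile["inferenceProfileId"])

# Plain (non-reasoning) request with prompt caching and a document attached
helper = ConverseHelper("Summarize the attached document in three bullet points.")
assistant = BedrockAssistant(
    model_id="us.amazon.nova-lite-v1:0",
    inference_config={"temperature": 0.2, "topP": 0.9, "maxTokens": 1024},
    system_prompt="You are a concise technical summarizer.",
    enabled_reasoning=False,
    enable_prompt_caching=True,
)
answer = assistant.converse(
    messages=helper.get_messages(),
    doc_path=Path("meeting_notes.md"),  # hypothetical file path
)
print(answer)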