"""Utility script to invoke Amazon Bedrock models via the Converse API.

Supports document upload, prompt caching, and extended reasoning.
"""
import re
import time
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, ClassVar, Final, Literal
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from loguru import logger
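# Third-party dependencies: boto3/botocore and loguru (rich is only used by the
# demo block at the bottom of this file).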
# Constants
DEFAULT_REGION: Final[str] = "us-west-2"
DEFAULT_TIMEOUT: Final[int] = 1000
MAX_RETRIES: Final[int] = 5
BASE_RETRY_DELAY: Final[int] = 5
MIN_BUDGET_TOKENS: Final[int] = 1024
SUPPORTED_DOC_FORMATS: Final[frozenset[str]] = frozenset(
    {
        "pdf",
        "csv",
        "doc",
        "docx",
        "xls",
        "xlsx",
        "html",
        "txt",
        "md",
    },
)
# Document name sanitization pattern
DOC_NAME_PATTERN: Final[re.Pattern] = re.compile(r"[\\/_,.]")
@dataclass(frozen=True)
class CacheConfig:
    """Configuration for prompt caching."""

    min_tokens: int
    max_checkpoints: int
class ConverseHelper:
    """Helper class for managing conversation messages."""

    def __init__(
        self,
        prompt: str | None = None,
        role: Literal["user", "assistant"] = "user",
    ) -> None:
        self.messages: list[dict[str, Any]] = []
        self.role = role
        if prompt:
            self.add_message(prompt)

    def add_message(
        self,
        content: str,
        role: Literal["user", "assistant"] = "user",
    ) -> None:
        """Add a message to the conversation."""
        self.role = role
        if not self.messages:
            self.messages.append({"role": self.role, "content": [{"text": content}]})
        else:
            self.messages[-1]["content"].append({"text": content})

    def add_cache_point(self) -> None:
        """Add a cache point to the conversation."""
        if self.messages:
            self.messages[-1]["content"].append({"cachePoint": {"type": "default"}})

    def get_messages(self) -> list[dict[str, Any]]:
        """Get the conversation messages."""
        return self.messages
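# Example usage (illustrative sketch): building a single cached user turn.
#   helper = ConverseHelper("Summarize the attached report.")
#   helper.add_cache_point()  # marks everything added so far as a cacheable prefix
#   messages = helper.get_messages()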
class BedrockControlPlane:
    """Control plane for Amazon Bedrock operations."""

    def __init__(self, region_name: str = DEFAULT_REGION) -> None:
        self.client = boto3.client("bedrock", region_name=region_name)
        self.inference_profiles_type: Literal["SYSTEM_DEFINED", "APPLICATION"] = (
            "SYSTEM_DEFINED"
        )

    def list_models(self, provider: str | None = None) -> list[dict[str, Any]]:
        """List Amazon Bedrock foundation models."""
        response = self.client.list_foundation_models()
        models = [
            model
            for model in response["modelSummaries"]
            if model["modelLifecycle"]["status"] == "ACTIVE"
        ]
        if provider:
            models = [
                model
                for model in models
                if model["providerName"].strip().lower() == provider.strip().lower()
            ]
        return models

    def describe_model(self, model_id: str) -> dict[str, Any]:
        """Describe a foundation model."""
        response = self.client.get_foundation_model(modelIdentifier=model_id)
        return response["modelDetails"]

    def list_inference_profiles(
        self,
        provider: str | None = None,
    ) -> list[dict[str, Any]]:
        """List Amazon Bedrock inference profiles."""
        try:
            response = self.client.list_inference_profiles(
                typeEquals=self.inference_profiles_type,
            )
            profiles = response["inferenceProfileSummaries"]
            if provider:
                logger.info(f"Filtering inference profiles by provider: {provider}")
                provider_lower = provider.strip().lower()
                profiles = [
                    profile
                    for profile in profiles
                    if (
                        provider_lower in profile["inferenceProfileId"].strip().lower()
                        and profile["status"] == "ACTIVE"
                    )
                ]
            return profiles
        except ClientError as e:
            if e.response["Error"]["Code"] == "ResourceNotFoundException":
                logger.warning(
                    "Inference profiles not found. This feature may not be available in your region.",
                )
                return []
            raise
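# Example usage (illustrative sketch): discovering active models and inference
# profiles with the control plane.
#   control_plane = BedrockControlPlane(region_name="us-west-2")
#   anthropic_models = control_plane.list_models(provider="Anthropic")
#   claude_profiles = control_plane.list_inference_profiles(provider="anthropic")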
class BedrockAssistant:
    """Amazon Bedrock assistant for conversational AI."""

    # Prompt cache configuration for supported models
    PROMPT_CACHE_SUPPORTED_MODELS: ClassVar[dict[str, CacheConfig]] = {
        "anthropic.claude-opus-4-20250514-v1:0": CacheConfig(
            min_tokens=1024,
            max_checkpoints=4,
        ),
        "anthropic.claude-sonnet-4-20250514-v1:0": CacheConfig(
            min_tokens=1024,
            max_checkpoints=4,
        ),
        "anthropic.claude-3-7-sonnet-20250219-v1:0": CacheConfig(
            min_tokens=1024,
            max_checkpoints=4,
        ),
        "anthropic.claude-3-5-haiku-20241022-v1:0": CacheConfig(
            min_tokens=2048,
            max_checkpoints=4,
        ),
        "anthropic.claude-3-5-sonnet-20241022-v2:0": CacheConfig(
            min_tokens=1024,
            max_checkpoints=4,
        ),
        "amazon.nova-micro-v1:0": CacheConfig(min_tokens=1024, max_checkpoints=4),
        "amazon.nova-lite-v1:0": CacheConfig(min_tokens=1024, max_checkpoints=4),
        "amazon.nova-pro-v1:0": CacheConfig(min_tokens=1024, max_checkpoints=4),
    }
    def __init__(  # noqa: PLR0913
        self,
        model_id: str,
        inference_config: dict[str, Any],
        system_prompt: str = "",
        *,
        enabled_reasoning: bool = False,
        budget_tokens: int = MIN_BUDGET_TOKENS,
        region_name: str = DEFAULT_REGION,
        enable_prompt_caching: bool = True,
    ) -> None:
        """Initialize Bedrock Assistant."""
        self._setup_client(region_name)
        self.model_id = model_id
        self.enabled_reasoning = enabled_reasoning
        self.enable_prompt_caching = enable_prompt_caching
        # Setup prompt caching
        self._configure_prompt_caching()
        # Setup system prompt
        self.system_prompt = self._prepare_system_prompt(system_prompt)
        # Setup inference configuration
        self.inference_config = self._validate_inference_config(
            inference_config,
            enabled_reasoning=enabled_reasoning,
            budget_tokens=budget_tokens,
        )
        self.budget_tokens = (
            max(budget_tokens, MIN_BUDGET_TOKENS)
            if enabled_reasoning
            else budget_tokens
        )

    def _setup_client(self, region_name: str) -> None:
        """Setup the Bedrock runtime client."""
        config = Config(read_timeout=DEFAULT_TIMEOUT)
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=region_name,
            config=config,
        )

    def _configure_prompt_caching(self) -> None:
        """Configure prompt caching settings."""
        model_base_id = self._get_base_model_id()
        if model_base_id in self.PROMPT_CACHE_SUPPORTED_MODELS:
            self.prompt_caching_supported = True
            cache_config = self.PROMPT_CACHE_SUPPORTED_MODELS[model_base_id]
            self.prompt_cache_min_tokens = cache_config.min_tokens
            self.prompt_cache_max_checkpoints = cache_config.max_checkpoints
            if self.enable_prompt_caching:
                logger.info(f"Prompt caching enabled for model {self.model_id}")
                logger.info(
                    f"Minimum tokens per cache checkpoint: {self.prompt_cache_min_tokens}",
                )
        else:
            self.prompt_caching_supported = False
            if self.enable_prompt_caching:
                logger.warning(
                    f"Prompt caching requested but not supported for model {self.model_id}"
                )
                self.enable_prompt_caching = False
    def _get_base_model_id(self) -> str:
        """Extract base model ID from potentially prefixed model ID."""
        return self.model_id.removeprefix("us.")

    def _prepare_system_prompt(self, system_prompt: str) -> list[dict[str, Any]]:
        """Prepare system prompt with optional cache point."""
        if not system_prompt:
            return []
        prompt_content = [{"text": system_prompt}]
        if self.enable_prompt_caching and self.prompt_caching_supported:
            prompt_content.append({"cachePoint": {"type": "default"}})
        return prompt_content

    def _validate_inference_config(
        self,
        inference_config: dict[str, Any],
        *,
        enabled_reasoning: bool,
        budget_tokens: int,
    ) -> dict[str, Any]:
        """Validate and adjust inference configuration for reasoning mode."""
        config = inference_config.copy()
        if not enabled_reasoning:
            return config
        # Reasoning mode requirements
        if config.get("temperature") != 1:
            logger.warning(
                "Temperature must be set to 1 when reasoning is enabled. Setting to 1."
            )
            config["temperature"] = 1
        if "topP" in config:
            logger.warning("Top P must be unset when reasoning is enabled. Unsetting.")
            del config["topP"]
        if config.get("maxTokens", 0) <= budget_tokens:
            new_max_tokens = (
                config.get("maxTokens", MIN_BUDGET_TOKENS) + MIN_BUDGET_TOKENS
            )
            logger.warning(
                f"Max tokens must be greater than budget tokens when reasoning is enabled. "
                f"Setting to maxTokens={new_max_tokens}.",
            )
            config["maxTokens"] = new_max_tokens
        return config
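    # Example of the reasoning-mode adjustments above: with enabled_reasoning=True and
    # budget_tokens=1024, {"temperature": 0.2, "topP": 0.5, "maxTokens": 1024} becomes
    # {"temperature": 1, "maxTokens": 2048} (temperature forced to 1, topP dropped,
    # maxTokens raised above the reasoning budget).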
    def set_inference_config(self, inference_config: dict[str, Any]) -> None:
        """Set the default inference config for subsequent requests."""
        self.inference_config = inference_config

    @staticmethod
    @lru_cache(maxsize=128)
    def _validate_document_format(doc_format: str) -> str:
        """Validate document format with caching."""
        if doc_format not in SUPPORTED_DOC_FORMATS:
            supported_formats = ", ".join(sorted(SUPPORTED_DOC_FORMATS))
            error_msg = f"Unsupported document format: {doc_format}. Supported formats: {supported_formats}"
            logger.warning(f"Supported document formats are: {supported_formats}")
            logger.error(error_msg)
            raise ValueError(error_msg)
        return doc_format

    @staticmethod
    def _sanitize_document_name(doc_path: Path) -> str:
        """Sanitize document name for Bedrock requirements."""
        # Remove spaces and replace problematic characters with hyphens
        name_without_spaces = "".join(doc_path.name.split())
        return DOC_NAME_PATTERN.sub("-", name_without_spaces)
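    # Example: _sanitize_document_name(Path("my report_v1.2.pdf")) -> "myreport-v1-2-pdf"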
    def prepare_document_message(self, doc_path: Path) -> dict[str, Any]:
        """Prepare a document message for the conversation."""
        if not doc_path.exists():
            error_msg = f"Document file not found: {doc_path}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)
        try:
            doc_bytes = doc_path.read_bytes()
        except Exception as e:
            error_msg = f"Failed to read document: {doc_path}"
            logger.error(f"{error_msg}: {e}")
            raise RuntimeError(error_msg) from e
        doc_format = self._validate_document_format(doc_path.suffix[1:].lower())
        doc_name = self._sanitize_document_name(doc_path)
        logger.info(f"Using document name: {doc_name}, format: {doc_format}")
        return {
            "name": doc_name,
            "format": doc_format,
            "source": {"bytes": doc_bytes},
        }
    def converse(
        self,
        messages: list[dict[str, Any]],
        doc_path: Path | None = None,
        return_type: str = "text",
        *,
        add_cache_point: bool = True,
    ) -> str | None:
        """Send a converse request to Amazon Bedrock with retry logic."""
        for attempt in range(MAX_RETRIES + 1):
            try:
                return self._make_converse_request(
                    messages,
                    doc_path,
                    return_type,
                    add_cache_point,
                )
            except ClientError as e:
                if e.response["Error"]["Code"] == "ThrottlingException":
                    if attempt >= MAX_RETRIES:
                        msg = "Max retries exceeded. Please try again later."
                        logger.error(msg)
                        raise RuntimeError(msg) from e
                    delay = BASE_RETRY_DELAY * (2 ** (attempt + 1))
                    logger.warning(
                        f"ThrottlingException: Retrying in {delay} seconds..."
                    )
                    time.sleep(delay)
                    continue
                error_msg = f"Bedrock API error: {e.response['Error']}"
                logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            except Exception as e:
                error_msg = "Unexpected error during Bedrock API call"
                logger.error(f"{error_msg}: {e}")
                raise RuntimeError(error_msg) from e
        logger.error("Exhausted all retry attempts")
        return None
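    # Retry schedule for ThrottlingException with the constants above: up to 5 retries
    # with exponential backoff delays of 10, 20, 40, 80, and 160 seconds before giving up.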
    def _make_converse_request(
        self,
        messages: list[dict[str, Any]],
        doc_path: Path | None,
        return_type: str,
        add_cache_point: bool = True,
    ) -> str:
        """Make the actual converse request to Bedrock."""
        # NOTE: add_cache_point is accepted for call symmetry but is not used here;
        # cache points are added to messages via ConverseHelper.add_cache_point().
        # Prepare additional model request fields for reasoning
        additional_fields = None
        if self.enabled_reasoning and self.budget_tokens > 0:
            additional_fields = {
                "reasoning_config": {
                    "type": "enabled",
                    "budget_tokens": self.budget_tokens,
                },
            }
        # Add document if provided
        if doc_path:
            document_dict = self.prepare_document_message(doc_path)
            messages[-1]["content"].append({"document": document_dict})
        # logger.info(f"Invoking model: {self.model_id} using Converse API...")
        response_data = self.client.converse(
            modelId=self.model_id,
            messages=messages,
            inferenceConfig=self.inference_config,
            additionalModelRequestFields=additional_fields,
            system=self.system_prompt,
        )
        self._log_cache_usage(response_data)
        return self._process_response(response_data, return_type)
    def _log_cache_usage(self, response_data: dict[str, Any]) -> None:
        """Log cache token usage if prompt caching was used."""
        if not (
            self.enable_prompt_caching
            and self.prompt_caching_supported
            and "usage" in response_data
        ):
            return
        usage = response_data["usage"]
        # logger.info(usage)
        cache_read_tokens = usage.get("cacheReadInputTokens", 0)
        cache_write_tokens = usage.get("cacheWriteInputTokens", 0)
        if cache_read_tokens > 0 or cache_write_tokens > 0:
            logger.info(
                f"Cache read tokens: {cache_read_tokens}, Cache write tokens: {cache_write_tokens}",
            )

    def _process_response(
        self,
        response: dict[str, Any],
        return_type: str = "text",
    ) -> str:
        """Process the raw API response into a user-friendly format."""
        content = response["output"]["message"]["content"]
        if not self.enabled_reasoning:
            return content[0]["text"]
        return (
            content[0]["reasoningContent"]["reasoningText"]["text"]
            if return_type == "thinking"
            else content[1]["text"]
        )
# Test code
if __name__ == "__main__":
    from rich import print as rprint

    helper = ConverseHelper()
    fixed_prompt = Path("prompts/test_prompt.txt").read_text(
        encoding="utf-8",
    )
    helper.add_message(fixed_prompt, role="user")
    helper.add_cache_point()
    inference_config = {
        "temperature": 0.2,
        "topP": 0.5,
        "maxTokens": 10240,
    }
    assistant = BedrockAssistant(
        model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        inference_config=inference_config,
        enabled_reasoning=True,
        enable_prompt_caching=True,
    )
    response = assistant.converse(
        doc_path=Path("test_document.txt"),
        messages=helper.get_messages(),
    )
    rprint(response)
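# Note: the demo above assumes prompts/test_prompt.txt and test_document.txt exist in
# the working directory, and that the active AWS credentials have Bedrock access in
# the default region (us-west-2).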