@praveenc
Last active June 10, 2025 17:58
Python utility script to invoke Amazon Bedrock via the Converse API. Supports document upload, prompt caching, and extended reasoning.
# /// script
# requires-python = ">=3.12.9"
# dependencies = [
# "loguru==0.7.3",
# "boto3==1.38.32",
# "rich==14.0.0",
# ]
# ///
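# The block above is PEP 723 inline script metadata, so a compatible runner can
# resolve the dependencies automatically; for example, `uv run <this-script>.py`
# should work, assuming uv is installed.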
import time
from pathlib import Path
from typing import Any, ClassVar, Literal
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from loguru import logger


class ConverseMessageHelper:
    """Helper to build message lists for the Bedrock Converse API."""

    def __init__(
        self,
        prompt: str | None = None,
        role: Literal["user", "assistant"] = "user",
    ):
        self.messages: list[dict[str, Any]] = []
        self.role = role
        if prompt:
            self.add_message(prompt)

    def add_message(self, content: str, role: Literal["user", "assistant"] = "user"):
        """Add a message to the conversation.

        The first call creates a new message; subsequent calls append another
        content block to the most recent message.
        """
        self.role = role
        if len(self.messages) == 0:
            self.messages.append({"role": self.role, "content": [{"text": content}]})
            return
        self.messages[-1]["content"].append({"text": content})

    def add_cache_point(self):
        """Add a prompt-caching checkpoint after the last content block."""
        self.messages[-1]["content"].append({"cachePoint": {"type": "default"}})

    def get_messages(self):
        """Get the conversation messages."""
        return self.messages


class BedrockControlPlane:
    """Thin wrapper around the Bedrock control-plane ("bedrock") client."""

    def __init__(self, region_name: str = "us-west-2"):
        self.client = boto3.client("bedrock", region_name=region_name)
        self.inference_profiles_type: Literal["SYSTEM_DEFINED", "APPLICATION"] = (
            "SYSTEM_DEFINED"
        )

    def list_models(self, provider: str | None = None):
        """
        List Amazon Bedrock foundation models.

        :param provider: Optional provider name to filter by
        :return: List of models
        """
        models = self.client.list_foundation_models()
        models = [
            model
            for model in models["modelSummaries"]
            if model["modelLifecycle"]["status"] == "ACTIVE"
        ]
        if provider:
            models = [
                model
                for model in models
                if model["providerName"].strip().lower() == provider.strip().lower()
            ]
        return models

    def describe_model(self, model_id: str):
        """
        Describe a foundation model.

        :param model_id: Model ID
        :return: Model description
        """
        return self.client.get_foundation_model(
            modelIdentifier=model_id,
        )["modelDetails"]

    def list_inference_profiles(self, provider: str | None = None):
        """
        List Amazon Bedrock inference profiles.

        :param provider: Optional provider name to filter by
        :return: List of inference profiles
        """
        try:
            profiles: list = self.client.list_inference_profiles(
                typeEquals=self.inference_profiles_type,
            )["inferenceProfileSummaries"]
            if provider:
                logger.info(f"Filtering inference profiles by provider: {provider}")
                profiles = [
                    profile
                    for profile in profiles
                    if provider.strip().lower()
                    in profile["inferenceProfileId"].strip().lower()
                    and profile["status"] == "ACTIVE"
                ]
            return profiles
        except ClientError as e:
            if e.response["Error"]["Code"] == "ResourceNotFoundException":
                logger.warning("No inference profiles found.")
                return []
            raise


class BedrockAssistant:
    """Wrapper around the Bedrock Runtime Converse API with retries,
    prompt caching, and optional extended reasoning."""

    # Models that support prompt caching and their minimum token requirements
    # Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
    PROMPT_CACHE_SUPPORTED_MODELS: ClassVar[dict[str, dict[str, int]]] = {
        "anthropic.claude-opus-4-20250514-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "anthropic.claude-sonnet-4-20250514-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "anthropic.claude-3-7-sonnet-20250219-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "anthropic.claude-3-5-haiku-20241022-v1:0": {"min_tokens": 2048, "max_checkpoints": 4},
        "anthropic.claude-3-5-sonnet-20241022-v2:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "amazon.nova-micro-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "amazon.nova-lite-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "amazon.nova-pro-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
    }

    def __init__(  # noqa: PLR0913
        self,
        model_id: str,
        inference_config: dict,
        system_prompt: str | None = None,
        *,
        enabled_reasoning: bool = False,
        budget_tokens: int = 1024,
        region_name: str = "us-west-2",
        enable_prompt_caching: bool = True,
    ):
        config = Config(
            read_timeout=1000,
        )
        min_budget_tokens = 1024
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=region_name,
            config=config,
        )
        self.model_id = model_id
        self.enabled_reasoning = enabled_reasoning
        self.enable_prompt_caching = enable_prompt_caching
        # Check if prompt caching is supported for this model
        self.prompt_caching_supported = False
        # Strip the cross-region inference prefix (e.g. "us.") before the lookup
        if self.model_id.startswith("us."):
            model_base_id = model_id[3:]
        else:
            model_base_id = model_id
        if model_base_id in self.PROMPT_CACHE_SUPPORTED_MODELS:
            self.prompt_caching_supported = True
            self.prompt_cache_min_tokens = self.PROMPT_CACHE_SUPPORTED_MODELS[model_base_id]["min_tokens"]
            self.prompt_cache_max_checkpoints = self.PROMPT_CACHE_SUPPORTED_MODELS[model_base_id]["max_checkpoints"]
            if enable_prompt_caching:
                logger.info(f"Prompt caching enabled for model {model_id}")
                logger.info(f"Minimum tokens per cache checkpoint: {self.prompt_cache_min_tokens}")
        elif enable_prompt_caching:
            logger.warning(f"Prompt caching requested but not supported for model {model_id}")
            self.enable_prompt_caching = False
        # Set up the system prompt, with a cache point if applicable
        if system_prompt:
            if self.enable_prompt_caching and self.prompt_caching_supported:
                # For supported models, system prompts can have cache points
                self.system_prompt = [
                    {"text": system_prompt},
                    {"cachePoint": {"type": "default"}},
                ]
            else:
                self.system_prompt = [{"text": system_prompt}]
        else:
            self.system_prompt = []
        self.inference_config = inference_config
        if self.enabled_reasoning:
            if budget_tokens < min_budget_tokens:
                logger.warning(
                    f"Budget tokens must be greater than {min_budget_tokens}. Setting to {min_budget_tokens}.",
                )
                self.budget_tokens = min_budget_tokens
            else:
                self.budget_tokens = budget_tokens
            # When reasoning is enabled:
            #   - temperature must be set to 1 and topP must be unset
            #   - maxTokens must be greater than budget_tokens
            if self.inference_config.get("temperature") != 1:
                logger.warning(
                    "Temperature must be set to 1 when reasoning is enabled. Setting to 1.",
                )
                self.inference_config["temperature"] = 1
            if "topP" in self.inference_config:
                logger.warning(
                    "Top P must be unset when reasoning is enabled. Unsetting.",
                )
                del self.inference_config["topP"]
            if self.inference_config["maxTokens"] <= self.budget_tokens:
                logger.warning(
                    f"Max tokens must be greater than budget tokens when reasoning is enabled. Setting to maxTokens={self.inference_config['maxTokens'] + 1024}.",
                )
                self.inference_config["maxTokens"] = (
                    self.inference_config["maxTokens"] + 1024
                )
        self.max_retries = 5
        self.base_delay = 5

    def set_inference_config(self, inference_config: dict):
        """Set the default inference config for subsequent requests."""
        self.inference_config = inference_config

    def prepare_document_message(
        self,
        doc_path: Path,
    ):
        """
        Prepare a document message for the conversation.

        :param doc_path: Path to the document file
        :return: Document dictionary
        """
        doc_bytes = doc_path.read_bytes()
        # Supported document formats are: "pdf" || "csv" || "doc" || "docx" || "xls" || "xlsx" || "html" || "txt" || "md"
        # Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html
        doc_format = doc_path.suffix[1:]
        if doc_format not in [
            "pdf",
            "csv",
            "doc",
            "docx",
            "xls",
            "xlsx",
            "html",
            "txt",
            "md",
        ]:
            error_msg = f"Unsupported document format: {doc_format}"
            logger.warning(
                "Supported document formats are: pdf, csv, doc, docx, xls, xlsx, html, txt, md"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
        # The document file name can only contain alphanumeric characters, whitespace
        # characters, hyphens, parentheses, and square brackets, and can't contain more
        # than one consecutive whitespace character.
        doc_name = (
            "".join(doc_path.name.split(" "))
            .replace("\\", "-")
            .replace("/", "-")
            .replace("_", "-")
            .replace(",", "-")
            .replace(".", "-")
        )
        logger.info(f"Using document name: {doc_name}")
        document_dict = {
            "name": doc_name,
            "format": doc_format,
            "source": {"bytes": doc_bytes},
        }
        return document_dict

    def converse(
        self,
        messages: list,
        doc_path: Path | None = None,
        return_type: str = "text",
        *,  # Force keyword-only for boolean parameters
        add_cache_point: bool = True,
    ):
        """
        Send a converse request to Amazon Bedrock.

        :param messages: List of message dictionaries with 'role', 'content'
        :param doc_path: Optional document path to be sent to the model
        :param return_type: Optional return type ('text' or 'thinking')
        :param add_cache_point: Whether to add a cache point to the message for prompt caching
        :return: Processed response
        """
        attempt = 0
        while attempt <= self.max_retries:
            try:
                # Prepare the reasoning request fields, if reasoning is enabled
                if self.enabled_reasoning and self.budget_tokens > 0:
                    additional_model_request_fields = {
                        "reasoning_config": {
                            "type": "enabled",
                            "budget_tokens": self.budget_tokens,
                        },
                    }
                else:
                    additional_model_request_fields = None
                # Attach the document to the last message, if provided
                if doc_path:
                    document_dict = self.prepare_document_message(doc_path)
                    messages[-1]["content"].append({"document": document_dict})
                logger.info(f"Invoking model: {self.model_id} using Converse API...")
                response_data = self.client.converse(
                    modelId=self.model_id,
                    messages=messages,
                    inferenceConfig=self.inference_config,
                    additionalModelRequestFields=additional_model_request_fields,
                    system=self.system_prompt,
                )
                # Log cache token usage if prompt caching was used
                if (
                    self.enable_prompt_caching
                    and self.prompt_caching_supported
                    and "usage" in response_data
                ):
                    usage = response_data["usage"]
                    logger.info(usage)
                    cache_read_tokens = usage.get("cacheReadInputTokens", 0)
                    cache_write_tokens = usage.get("cacheWriteInputTokens", 0)
                    if cache_read_tokens > 0 or cache_write_tokens > 0:
                        logger.info(f"Cache read tokens: {cache_read_tokens}, Cache write tokens: {cache_write_tokens}")
                return self._process_response(response_data, return_type)
            except ClientError as e:
                if e.response["Error"]["Code"] == "ThrottlingException":
                    if attempt >= self.max_retries:
                        msg = "Max retries exceeded. Please try again later."
                        logger.error(msg)
                        raise RuntimeError(msg) from e
                    attempt += 1
                    delay = self.base_delay * (2**attempt)
                    logger.warning(
                        f"ThrottlingException: Retrying in {delay} seconds...",
                    )
                    time.sleep(delay)
                    continue  # Retry the request
                error_msg = f"Bedrock API error: {e.response['Error']}"
                logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            except Exception as e:
                error_msg = "Unexpected error during Bedrock API call"
                logger.error(f"{error_msg}: {e!s}")
                raise RuntimeError(error_msg) from e
        logger.error("Reached end of converse method without returning or raising")
        return None

    def _process_response(self, response, return_type: str = "text"):
        """Process the raw API response into a user-friendly format."""
        content = response["output"]["message"]["content"]
        if not self.enabled_reasoning:
            return content[0]["text"]
        # With reasoning enabled, the first content block holds the reasoning
        # text and the second holds the final answer.
        return (
            content[0]["reasoningContent"]["reasoningText"]["text"]
            if return_type == "thinking"
            else content[1]["text"]
        )


# Example usage (requires AWS credentials and the prompt files referenced below)
if __name__ == "__main__":
    from rich import print as rprint

    helper = ConverseMessageHelper()
    # The prompt is organized into fixed and variable parts. The fixed portion usually
    # holds IDENTITY, INSTRUCTIONS, RELEVANT CONTEXT, etc., and is followed by a cache point.
    fixed_prompt = Path("prompts/prompt-fixed.txt").read_text(encoding="utf-8")
    # The variable portion holds the dynamic parts of each call, extracted into a separate file.
    variable_prompt = Path("prompts/prompt-variable.txt").read_text(encoding="utf-8")
    helper.add_message(fixed_prompt, role="user")
    helper.add_cache_point()
    helper.add_message(variable_prompt, role="user")
    inference_config = {
        "temperature": 0.2,
        "topP": 0.5,
        "maxTokens": 10240,
    }
    assistant = BedrockAssistant(
        model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        inference_config=inference_config,
        enabled_reasoning=True,
        enable_prompt_caching=True,
    )
    response = assistant.converse(
        messages=helper.get_messages(),
        return_type="text",
    )
    rprint(response)
    # Control-plane utility class to list models and inference profiles
    control_plane = BedrockControlPlane()
    print(control_plane.list_models(provider="amazon"))
    rprint(control_plane.list_inference_profiles(provider="anthropic"))
    # rprint(helper.get_messages())
    # helper.add_message("Hello, how are you?", role="user")
    # helper.add_message("message 2", role="assistant")
    # rprint(helper.get_messages())
    # rprint(control_plane.describe_model("anthropic.claude-3-7-sonnet-20250219-v1:0"))
    # doc_path = Path("docs/abalone-xgb-modelcard-1674518818-4cb8.pdf")
    # helper2 = ConverseMessageHelper()
    # prompt = "Write a high-quality short summary of the attached document?"
    # helper2.add_message(prompt, role="user")
    # rprint(helper2.get_messages())
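    # A minimal sketch (kept commented out) of how the document-upload path could be
    # exercised with the helper2 example above; it assumes the same `assistant`
    # instance and that the referenced PDF exists locally, so treat it as illustrative only.
    # doc_response = assistant.converse(
    #     messages=helper2.get_messages(),
    #     doc_path=doc_path,
    # )
    # rprint(doc_response)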