Python utility script to invoke Amazon Bedrock models through the Converse API. Supports document upload, prompt caching, and reasoning (extended thinking).
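The script declares its dependencies inline (PEP 723 metadata at the top of the file), so it can be run directly with a tool that reads inline script metadata, for example uv: save it under any name, say bedrock_converse.py (a hypothetical name), and run "uv run bedrock_converse.py". Valid AWS credentials and access to the chosen Bedrock model in the target region are assumed.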
# /// script
# requires-python = ">=3.12.9"
# dependencies = [
#     "loguru==0.7.3",
#     "boto3==1.38.32",
#     "rich==14.0.0",
# ]
# ///
import time
from pathlib import Path
from typing import Any, ClassVar, Literal

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from loguru import logger

class ConverseMessageHelper:
    def __init__(
        self,
        prompt: str | None = None,
        role: Literal["user", "assistant"] = "user",
    ):
        self.messages: list[dict[str, Any]] = []
        self.role = role
        if prompt:
            self.add_message(prompt)

    def add_message(self, content: str, role: Literal["user", "assistant"] = "user"):
        """Add a message to the conversation."""
        self.role = role
        if len(self.messages) == 0:
            self.messages.append({"role": self.role, "content": [{"text": content}]})
            return
        self.messages[-1]["content"].append({"text": content})

    def add_cache_point(self):
        """Add a cache point to the conversation."""
        self.messages[-1]["content"].append({"cachePoint": {"type": "default"}})

    def get_messages(self):
        """Get the conversation messages."""
        return self.messages

class BedrockControlPlane:
    def __init__(self, region_name: str = "us-west-2"):
        self.client = boto3.client("bedrock", region_name=region_name)
        self.inference_profiles_type: Literal["SYSTEM_DEFINED", "APPLICATION"] = (
            "SYSTEM_DEFINED"
        )

    def list_models(self, provider: str | None = None):
        """
        List Amazon Bedrock foundation models.

        :param provider: Optional provider name to filter by
        :return: List of models
        """
        models = self.client.list_foundation_models()
        models = [
            model
            for model in models["modelSummaries"]
            if model["modelLifecycle"]["status"] == "ACTIVE"
        ]
        if provider:
            models = [
                model
                for model in models
                if model["providerName"].strip().lower() == provider.strip().lower()
            ]
        return models

    def describe_model(self, model_id: str):
        """
        Describe a foundation model.

        :param model_id: Model ID
        :return: Model description
        """
        return self.client.get_foundation_model(
            modelIdentifier=model_id,
        )["modelDetails"]

    def list_inference_profiles(self, provider: str | None = None):
        """
        List Amazon Bedrock inference profiles.

        :param provider: Optional provider name to filter by
        :return: List of inference profiles
        """
        try:
            profiles: list = self.client.list_inference_profiles(
                typeEquals=self.inference_profiles_type,
            )["inferenceProfileSummaries"]
            if provider:
                logger.info(f"Filtering inference profiles by provider: {provider}")
                profiles = [
                    profile
                    for profile in profiles
                    if provider.strip().lower()
                    in profile["inferenceProfileId"].strip().lower()
                    and profile["status"] == "ACTIVE"
                ]
            return profiles
        except ClientError as e:
            if e.response["Error"]["Code"] == "ResourceNotFoundException":
                logger.warning("No inference profiles found.")
                return []
            raise

class BedrockAssistant:
    # Models that support prompt caching and their minimum token requirements
    # Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
    PROMPT_CACHE_SUPPORTED_MODELS: ClassVar[dict[str, dict[str, int]]] = {
        "anthropic.claude-opus-4-20250514-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "anthropic.claude-sonnet-4-20250514-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "anthropic.claude-3-7-sonnet-20250219-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "anthropic.claude-3-5-haiku-20241022-v1:0": {"min_tokens": 2048, "max_checkpoints": 4},
        "anthropic.claude-3-5-sonnet-20241022-v2:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "amazon.nova-micro-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "amazon.nova-lite-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
        "amazon.nova-pro-v1:0": {"min_tokens": 1024, "max_checkpoints": 4},
    }

    def __init__(  # noqa: PLR0913
        self,
        model_id: str,
        inference_config: dict,
        system_prompt: str | None = None,
        *,
        enabled_reasoning: bool = False,
        budget_tokens: int = 1024,
        region_name: str = "us-west-2",
        enable_prompt_caching: bool = True,
    ):
        config = Config(
            read_timeout=1000,
        )
        min_budget_tokens = 1024
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=region_name,
            config=config,
        )
        self.model_id = model_id
        self.enabled_reasoning = enabled_reasoning
        self.enable_prompt_caching = enable_prompt_caching
        # Check if prompt caching is supported for this model
        # (strip the cross-region inference profile prefix, e.g. "us.", before the lookup)
        self.prompt_caching_supported = False
        if self.model_id.startswith("us."):
            model_base_id = model_id[3:]
        else:
            model_base_id = model_id
        if model_base_id in self.PROMPT_CACHE_SUPPORTED_MODELS:
            self.prompt_caching_supported = True
            self.prompt_cache_min_tokens = self.PROMPT_CACHE_SUPPORTED_MODELS[model_base_id]["min_tokens"]
            self.prompt_cache_max_checkpoints = self.PROMPT_CACHE_SUPPORTED_MODELS[model_base_id]["max_checkpoints"]
            if enable_prompt_caching:
                logger.info(f"Prompt caching enabled for model {model_id}")
                logger.info(f"Minimum tokens per cache checkpoint: {self.prompt_cache_min_tokens}")
        elif enable_prompt_caching:
            logger.warning(f"Prompt caching requested but not supported for model {model_id}")
            self.enable_prompt_caching = False
        # Set up the system prompt, with a cache point if applicable
        if system_prompt:
            if self.enable_prompt_caching and self.prompt_caching_supported:
                # For supported models, system prompts can have cache points
                self.system_prompt = [
                    {"text": system_prompt},
                    {"cachePoint": {"type": "default"}},
                ]
            else:
                self.system_prompt = [{"text": system_prompt}]
        else:
            self.system_prompt = []
        self.inference_config = inference_config
        if self.enabled_reasoning:
            if budget_tokens < min_budget_tokens:
                logger.warning(
                    f"Budget tokens must be at least {min_budget_tokens}. Setting to {min_budget_tokens}.",
                )
                self.budget_tokens = min_budget_tokens
            else:
                self.budget_tokens = budget_tokens
            # When reasoning is enabled:
            # temperature must be 1, top_p must be unset,
            # and max_tokens must be greater than budget_tokens.
            if self.inference_config.get("temperature") != 1:
                logger.warning(
                    "Temperature must be set to 1 when reasoning is enabled. Setting to 1.",
                )
                self.inference_config["temperature"] = 1
            if "topP" in self.inference_config:
                logger.warning(
                    "Top P must be unset when reasoning is enabled. Unsetting.",
                )
                del self.inference_config["topP"]
            if self.inference_config.get("maxTokens", 0) <= self.budget_tokens:
                logger.warning(
                    f"Max tokens must be greater than budget tokens when reasoning is enabled. Setting maxTokens={self.budget_tokens + 1024}.",
                )
                self.inference_config["maxTokens"] = self.budget_tokens + 1024
        self.max_retries = 5
        self.base_delay = 5

    def set_inference_config(self, inference_config: dict):
        """Set the default inference config for subsequent requests."""
        self.inference_config = inference_config

    def prepare_document_message(
        self,
        doc_path: Path,
    ):
        """
        Prepare a document message for the conversation.

        :param doc_path: Path to the document file
        :return: Document dictionary
        """
        doc_bytes = doc_path.read_bytes()
        # Supported document formats are: "pdf" || "csv" || "doc" || "docx" || "xls" || "xlsx" || "html" || "txt" || "md"
        # Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html
        doc_format = doc_path.suffix[1:]
        if doc_format not in [
            "pdf",
            "csv",
            "doc",
            "docx",
            "xls",
            "xlsx",
            "html",
            "txt",
            "md",
        ]:
            error_msg = f"Unsupported document format: {doc_format}"
            logger.warning(
                "Supported document formats are: pdf, csv, doc, docx, xls, xlsx, html, txt, md",
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
        # The document name can only contain alphanumeric characters, whitespace, hyphens, parentheses, and square brackets,
        # and can't contain more than one consecutive whitespace character.
        doc_name = (
            "".join(doc_path.name.split(" "))
            .replace("\\", "-")
            .replace("/", "-")
            .replace("_", "-")
            .replace(",", "-")
            .replace(".", "-")
        )
        logger.info(f"Using document name: {doc_name}")
        document_dict = {
            "name": doc_name,
            "format": doc_format,
            "source": {"bytes": doc_bytes},
        }
        return document_dict

    def converse(
        self,
        messages: list,
        doc_path: Path | None = None,
        return_type: str = "text",
        *,  # Force keyword-only for boolean parameters
        add_cache_point: bool = True,
    ):
        """
        Send a converse request to Amazon Bedrock.

        :param messages: List of message dictionaries with 'role', 'content'
        :param doc_path: Optional document path to be sent to the model
        :param return_type: Optional return type ('text' or 'thinking')
        :param add_cache_point: Whether to add a cache point to the message for prompt caching
        :return: Processed response
        """
        # Prepare the reasoning request fields once, outside the retry loop
        if self.enabled_reasoning and self.budget_tokens > 0:
            additional_model_request_fields = {
                "reasoning_config": {
                    "type": "enabled",
                    "budget_tokens": self.budget_tokens,
                },
            }
        else:
            additional_model_request_fields = None
        # Attach the document once, so retries don't append it a second time
        if doc_path:
            document_dict = self.prepare_document_message(doc_path)
            messages[-1]["content"].append({"document": document_dict})
        attempt = 0
        while attempt <= self.max_retries:
            try:
                logger.info(f"Invoking model: {self.model_id} using Converse API...")
                response_data = self.client.converse(
                    modelId=self.model_id,
                    messages=messages,
                    inferenceConfig=self.inference_config,
                    additionalModelRequestFields=additional_model_request_fields,
                    system=self.system_prompt,
                )
                # Log cache token usage if prompt caching was used
                if (
                    self.enable_prompt_caching
                    and self.prompt_caching_supported
                    and "usage" in response_data
                ):
                    usage = response_data["usage"]
                    logger.info(usage)
                    cache_read_tokens = usage.get("cacheReadInputTokens", 0)
                    cache_write_tokens = usage.get("cacheWriteInputTokens", 0)
                    if cache_read_tokens > 0 or cache_write_tokens > 0:
                        logger.info(f"Cache read tokens: {cache_read_tokens}, Cache write tokens: {cache_write_tokens}")
                return self._process_response(response_data, return_type)
            except ClientError as e:
                if e.response["Error"]["Code"] == "ThrottlingException":
                    if attempt >= self.max_retries:
                        msg = "Max retries exceeded. Please try again later."
                        logger.error(msg)
                        raise RuntimeError(msg) from e
                    attempt += 1
                    delay = self.base_delay * (2**attempt)
                    logger.warning(
                        f"ThrottlingException: Retrying in {delay} seconds...",
                    )
                    time.sleep(delay)
                    continue  # Continue with the next retry attempt
                error_msg = f"Bedrock API error: {e.response['Error']}"
                logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            except Exception as e:
                error_msg = "Unexpected error during Bedrock API call"
                logger.error(f"{error_msg}: {e!s}")
                raise RuntimeError(error_msg) from e
        logger.error("Reached end of converse method without returning or raising")
        return None

    def _process_response(self, response, return_type: str = "text"):
        """Process the raw API response into a user-friendly format."""
        content = response["output"]["message"]["content"]
        if not self.enabled_reasoning:
            return content[0]["text"]
        return (
            content[0]["reasoningContent"]["reasoningText"]["text"]
            if return_type == "thinking"
            else content[1]["text"]
        )

# Example usage -- run this file directly to exercise the helpers
if __name__ == "__main__":
    from rich import print as rprint

    helper = ConverseMessageHelper()
    # The prompt is organized into fixed and variable parts. The fixed portion usually holds IDENTITY, INSTRUCTIONS, RELEVANT CONTEXT, etc.
    fixed_prompt = Path("prompts/prompt-fixed.txt").read_text(encoding="utf-8")
    # The variable portion holds the dynamic parts of each call, extracted to a separate file.
    variable_prompt = Path("prompts/prompt-variable.txt").read_text(encoding="utf-8")
    helper.add_message(fixed_prompt, role="user")
    helper.add_cache_point()
    helper.add_message(variable_prompt, role="user")
    inference_config = {
        "temperature": 0.2,
        "topP": 0.5,
        "maxTokens": 10240,
    }
    assistant = BedrockAssistant(
        model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        inference_config=inference_config,
        enabled_reasoning=True,
        enable_prompt_caching=True,
    )
    response = assistant.converse(
        messages=helper.get_messages(),
        return_type="text",
    )
    rprint(response)

    # Control-plane utility class to list models and inference profiles
    control_plane = BedrockControlPlane()
    print(control_plane.list_models(provider="amazon"))
    rprint(control_plane.list_inference_profiles(provider="anthropic"))

    # rprint(helper.get_messages())
    # helper.add_message("Hello, how are you?", role="user")
    # helper.add_message("message 2", role="assistant")
    # rprint(helper.get_messages())
    # rprint(control_plane.describe_model("anthropic.claude-3-7-sonnet-20250219-v1:0"))
    # doc_path = Path("docs/abalone-xgb-modelcard-1674518818-4cb8.pdf")
    # helper2 = ConverseMessageHelper()
    # prompt = "Write a high-quality short summary of the attached document?"
    # helper2.add_message(prompt, role="user")
    # rprint(helper2.get_messages())
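For the document-upload path hinted at in the commented-out lines above, a minimal sketch could be appended to the __main__ block (so the classes and imports above are in scope); the PDF path and prompt are hypothetical placeholders, and reasoning is simply left off:

    doc_helper = ConverseMessageHelper()
    doc_helper.add_message("Write a short summary of the attached document.", role="user")

    doc_assistant = BedrockAssistant(
        model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        inference_config={"temperature": 0.2, "maxTokens": 4096},
        enabled_reasoning=False,
        enable_prompt_caching=True,
    )

    # converse() attaches the document to the last user message before invoking the model
    summary = doc_assistant.converse(
        messages=doc_helper.get_messages(),
        doc_path=Path("docs/example-report.pdf"),  # hypothetical path; any supported format works
    )
    rprint(summary)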