Skip to content

Instantly share code, notes, and snippets.

@santolucito
Created October 28, 2024 08:30
Show Gist options
  • Save santolucito/728c2da6eff51113ddc4ad14a56594ab to your computer and use it in GitHub Desktop.
Save santolucito/728c2da6eff51113ddc4ad14a56594ab to your computer and use it in GitHub Desktop.
claude aoe2 computer use
"""
Agentic sampling loop that calls the Anthropic API and local implementation of anthropic-defined computer use tools.
basically the same as https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py
just with a few changes to reduce overhead
"""
import platform
from collections.abc import Callable
from datetime import datetime
from enum import StrEnum
from typing import Any, cast
import asyncio
import httpx
from anthropic import (
Anthropic,
AnthropicBedrock,
AnthropicVertex,
APIError,
APIResponseValidationError,
APIStatusError,
)
from anthropic.types.beta import (
BetaCacheControlEphemeralParam,
BetaContentBlockParam,
BetaImageBlockParam,
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolResultBlockParam,
BetaToolUseBlockParam,
)
from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
class APIProvider(StrEnum):
ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
VERTEX = "vertex"
# Default model identifier for each backend. Every value names the same
# claude-3-5-sonnet 20241022 model in that provider's own naming scheme.
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
    APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",  # first-party API
    APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",  # Bedrock id
    APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",  # Vertex AI id
}
# System prompt for driving Age of Empires 2 over SteamLink with the
# computer-use tools. Edit it so the model has accurate context about the
# environment it runs in, plus any information that helps with the task.
# It is a plain string literal: the text contains no placeholders, so no
# f-string prefix is needed.
SYSTEM_PROMPT = """<SYSTEM_CAPABILITY>
* You are utilising SteamLink to play a game of Age of Empires 2.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* Always ensure that actions you take are actually executed by checking the games stats at the top left of the screen.
</IMPORTANT>"""
async def sampling_loop(
    *,
    model: str,
    provider: APIProvider,
    system_prompt_suffix: str,
    messages: list[BetaMessageParam],
    output_callback: Callable[[BetaContentBlockParam], None],
    tool_output_callback: Callable[[ToolResult, str], None],
    api_response_callback: Callable[
        [httpx.Request, httpx.Response | object | None, Exception | None], None
    ],
    api_key: str,
    only_n_most_recent_images: int | None = None,
    max_tokens: int = 4096,
    max_message_pairs: int = 6,
    request_delay_seconds: float = 15,
):
    """
    Agentic sampling loop for the assistant/tool interaction of computer use.

    Each turn sends the (trimmed) conversation to the Messages API and passes
    every returned content block to ``output_callback``. Text blocks are also
    typed on screen through the computer tool, and ``tool_use`` blocks are
    executed. Their results are appended as the next user turn. The loop ends
    when a response contains no tool calls, or when the API request fails.

    Args:
        model: Model identifier sent to the Messages API.
        provider: Backend to build the API client for.
        system_prompt_suffix: Extra text appended to ``SYSTEM_PROMPT``.
        messages: Conversation history. New turns are appended to it in place
            until the first trim, after which a new list is used, so callers
            should use the returned value.
        output_callback: Receives a startup message and every response block.
        tool_output_callback: Receives each tool result and its tool_use id.
        api_response_callback: Receives the request, the response or error
            body, and the exception (or None) for every API call.
        api_key: API key, used only for ``APIProvider.ANTHROPIC``.
        only_n_most_recent_images: If truthy, older screenshots are pruned
            down to roughly this many images.
        max_tokens: ``max_tokens`` for each request.
        max_message_pairs: Number of most recent assistant/user turn pairs
            kept (besides the first message) when trimming the history.
            A value <= 0 disables trimming.
        request_delay_seconds: Pause before each request, to avoid rate
            limiting. It is awaited, so the event loop keeps running.

    Returns:
        The conversation history at the point the loop stopped.

    Raises:
        ValueError: If ``provider`` is not a known ``APIProvider``.
    """
    tool_collection = ToolCollection(
        ComputerTool(),
        BashTool(),
        EditTool(),
    )
    system = BetaTextBlockParam(
        type="text",
        text=f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}",
    )
    output_callback("Running AOE2 Player version 0.4")

    # Provider setup is loop-invariant, so the client is built once. Building
    # it per turn was wasteful, and an unknown provider left `client` unbound.
    enable_prompt_caching = False
    if provider == APIProvider.ANTHROPIC:
        client = Anthropic(api_key=api_key)
        enable_prompt_caching = True
    elif provider == APIProvider.VERTEX:
        client = AnthropicVertex()
    elif provider == APIProvider.BEDROCK:
        client = AnthropicBedrock()
    else:
        raise ValueError(f"Unsupported API provider: {provider!r}")

    betas = [COMPUTER_USE_BETA_FLAG]
    # Images are removed only in multiples of this, to limit cache busting.
    image_truncation_threshold = 10
    if enable_prompt_caching:
        betas.append(PROMPT_CACHING_BETA_FLAG)
        # Is it ever worth it to bust the cache with prompt caching?
        image_truncation_threshold = 50
        system["cache_control"] = {"type": "ephemeral"}

    while True:
        messages = _trim_message_history(messages, max_message_pairs)
        if enable_prompt_caching:
            _inject_prompt_caching(messages)
        if only_n_most_recent_images:
            _maybe_filter_to_n_most_recent_images(
                messages,
                only_n_most_recent_images,
                min_removal_threshold=image_truncation_threshold,
            )

        # Pause between requests to avoid rate limiting. This must be awaited:
        # time.sleep() would freeze the whole event loop for the duration.
        await asyncio.sleep(request_delay_seconds)

        # raw_response provides debug information to the UI. The SDK client is
        # synchronous, so the request runs in a worker thread. This keeps the
        # event loop responsive while the model is generating.
        try:
            raw_response = await asyncio.to_thread(
                client.beta.messages.with_raw_response.create,
                max_tokens=max_tokens,
                messages=messages,
                model=model,
                system=[system],
                tools=tool_collection.to_params(),
                betas=betas,
            )
        except (APIStatusError, APIResponseValidationError) as e:
            api_response_callback(e.request, e.response, e)
            return messages
        except APIError as e:
            api_response_callback(e.request, e.body, e)
            return messages

        api_response_callback(
            raw_response.http_response.request, raw_response.http_response, None
        )
        response = raw_response.parse()
        response_params = _response_to_params(response)
        messages.append({"role": "assistant", "content": response_params})

        tool_result_content: list[BetaToolResultBlockParam] = []
        for content_block in response_params:
            output_callback(content_block)
            if content_block["type"] == "text":
                # Type the model's text on screen (presumably into game chat).
                await _type_text_into_chat(tool_collection, content_block["text"])
            if content_block["type"] == "tool_use":
                result = await tool_collection.run(
                    name=content_block["name"],
                    tool_input=cast(dict[str, Any], content_block["input"]),
                )
                tool_result_content.append(
                    _make_api_tool_result(result, content_block["id"])
                )
                tool_output_callback(result, content_block["id"])

        if not tool_result_content:
            return messages
        messages.append({"content": tool_result_content, "role": "user"})


def _trim_message_history(
    messages: list[BetaMessageParam], max_message_pairs: int
) -> list[BetaMessageParam]:
    """Keep the first message plus the most recent turns of the history.

    The first message is always kept. This way the conversation still opens
    with the caller's first turn, and the original request is not lost. The
    retained tail starts at an ``assistant`` message, so every
    ``tool_result`` block in it follows the turn holding its ``tool_use``.
    Returns ``messages`` itself when nothing is trimmed, otherwise a new list.
    """
    if max_message_pairs <= 0 or len(messages) <= 2 * max_message_pairs + 1:
        return messages
    start = len(messages) - 2 * max_message_pairs  # always >= 2 here
    # Advance to an assistant turn so tool_use/tool_result pairs stay intact.
    while start < len(messages) and messages[start]["role"] != "assistant":
        start += 1
    if start >= len(messages):
        return messages
    return [messages[0], *messages[start:]]


async def _type_text_into_chat(
    tool_collection: ToolCollection,
    text: str,
    chunk_size: int = 100,
    chunk_delay: float = 3,
) -> None:
    """Type ``text`` on screen via the computer tool, ``chunk_size`` chars at a time.

    Each chunk is sent as: press Return, type the chunk, wait
    ``chunk_delay`` seconds, then press Return again. Presumably the first
    Return opens the game's chat box and the second sends the message; this
    is not verified here. Tool results are ignored.
    """
    for start in range(0, len(text), chunk_size):
        chunk = text[start:start + chunk_size]
        await tool_collection.run(
            name="computer",
            tool_input={"action": "key", "text": "Return"}
        )
        await tool_collection.run(
            name="computer",
            tool_input={"action": "type", "text": chunk}
        )
        await asyncio.sleep(chunk_delay)
        await tool_collection.run(
            name="computer",
            tool_input={"action": "key", "text": "Return"}
        )
def _maybe_filter_to_n_most_recent_images(
    messages: list[BetaMessageParam],
    images_to_keep: int,
    min_removal_threshold: int,
):
    """
    Remove all but the final `images_to_keep` tool_result images, in place.

    Screenshots are assumed to lose value as the conversation progresses.
    To reduce how often the implicit prompt cache is broken, images are
    removed oldest-first and only in whole multiples of
    `min_removal_threshold`. As a result, up to `min_removal_threshold - 1`
    images beyond `images_to_keep` may remain. A threshold of 1 or less
    disables chunking. Previously, 0 raised ZeroDivisionError and a negative
    value made the modulo negative, which removed too many images.

    Returns `messages` when `images_to_keep` is None, otherwise None.
    """
    if images_to_keep is None:
        return messages

    tool_result_blocks = cast(
        list[BetaToolResultBlockParam],
        [
            item
            for message in messages
            for item in (
                message["content"] if isinstance(message["content"], list) else []
            )
            if isinstance(item, dict) and item.get("type") == "tool_result"
        ],
    )

    total_images = sum(
        1
        for tool_result in tool_result_blocks
        for content in tool_result.get("content", [])
        if isinstance(content, dict) and content.get("type") == "image"
    )

    images_to_remove = total_images - images_to_keep
    # For better cache behavior, remove in whole chunks of the threshold.
    if min_removal_threshold > 1:
        images_to_remove -= images_to_remove % min_removal_threshold
    if images_to_remove <= 0:
        return None

    for tool_result in tool_result_blocks:
        if isinstance(tool_result.get("content"), list):
            new_content = []
            for content in tool_result.get("content", []):
                if isinstance(content, dict) and content.get("type") == "image":
                    if images_to_remove > 0:
                        images_to_remove -= 1
                        continue
                new_content.append(content)
            tool_result["content"] = new_content
    return None
def _response_to_params(
    response: BetaMessage,
) -> list[BetaTextBlockParam | BetaToolUseBlockParam]:
    """
    Convert the content blocks of an API response into request-param dicts.

    Text blocks become ``{"type": "text", "text": ...}``. Every other block
    is serialized with ``model_dump()`` and typed as a tool_use param.
    """
    return [
        {"type": "text", "text": block.text}
        if isinstance(block, BetaTextBlock)
        else cast(BetaToolUseBlockParam, block.model_dump())
        for block in response.content
    ]
def _inject_prompt_caching(
    messages: list[BetaMessageParam],
):
    """
    Set cache breakpoints for the 3 most recent user turns, in place.

    Only user messages with list content are considered. The last content
    block of each of the three newest such messages gets an ephemeral
    ``cache_control``. The fourth-newest has its breakpoint removed, and
    older messages are not visited. One breakpoint is left for the
    tools/system prompt, to be shared across sessions.
    """
    user_contents = (
        message["content"]
        for message in reversed(messages)
        if message["role"] == "user" and isinstance(message["content"], list)
    )
    for position, content in enumerate(user_contents):
        if position < 3:
            content[-1]["cache_control"] = BetaCacheControlEphemeralParam(
                {"type": "ephemeral"}
            )
        else:
            content[-1].pop("cache_control", None)
            # We'll only ever have one extra turn per loop, so this is the
            # only message that can still hold a stale breakpoint.
            break
def _make_api_tool_result(
    result: ToolResult, tool_use_id: str
) -> BetaToolResultBlockParam:
    """
    Convert an agent ToolResult to an API ToolResultBlockParam.

    A result with an error becomes a plain-string content with ``is_error``
    set to True. Otherwise the content is a list holding, in order, an
    optional text block (the output) and an optional base64 PNG image block.
    Any ``result.system`` note is prefixed to the error/output text.
    """
    if result.error:
        return {
            "type": "tool_result",
            "content": _maybe_prepend_system_tool_result(result, result.error),
            "tool_use_id": tool_use_id,
            "is_error": True,
        }

    blocks: list[BetaTextBlockParam | BetaImageBlockParam] = []
    if result.output:
        text_block = {
            "type": "text",
            "text": _maybe_prepend_system_tool_result(result, result.output),
        }
        blocks.append(text_block)
    if result.base64_image:
        image_source = {
            "type": "base64",
            "media_type": "image/png",
            "data": result.base64_image,
        }
        blocks.append({"type": "image", "source": image_source})

    return {
        "type": "tool_result",
        "content": blocks,
        "tool_use_id": tool_use_id,
        "is_error": False,
    }
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
    """Return ``result_text``, preceded by ``result.system`` in ``<system>`` tags if set.

    The tagged note sits on its own line above the original text.
    """
    if not result.system:
        return result_text
    return f"<system>{result.system}</system>\n{result_text}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment