Created
October 28, 2024 08:30
-
-
Save santolucito/728c2da6eff51113ddc4ad14a56594ab to your computer and use it in GitHub Desktop.
claude aoe2 computer use
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Agentic sampling loop that calls the Anthropic API and local implementation of anthropic-defined computer use tools. | |
basically the same as https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py | |
just with a few changes to reduce overhead | |
""" | |
import platform | |
from collections.abc import Callable | |
from datetime import datetime | |
from enum import StrEnum | |
from typing import Any, cast | |
import asyncio | |
import httpx | |
from anthropic import ( | |
Anthropic, | |
AnthropicBedrock, | |
AnthropicVertex, | |
APIError, | |
APIResponseValidationError, | |
APIStatusError, | |
) | |
from anthropic.types.beta import ( | |
BetaCacheControlEphemeralParam, | |
BetaContentBlockParam, | |
BetaImageBlockParam, | |
BetaMessage, | |
BetaMessageParam, | |
BetaTextBlock, | |
BetaTextBlockParam, | |
BetaToolResultBlockParam, | |
BetaToolUseBlockParam, | |
) | |
from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult | |
# Beta flag that enables the computer-use tools; always included in `betas`.
COMPUTER_USE_BETA_FLAG: str = "computer-use-2024-10-22"
# Beta flag for prompt caching; sampling_loop adds it only for APIProvider.ANTHROPIC.
PROMPT_CACHING_BETA_FLAG: str = "prompt-caching-2024-07-31"
class APIProvider(StrEnum): | |
ANTHROPIC = "anthropic" | |
BEDROCK = "bedrock" | |
VERTEX = "vertex" | |
# Default model id per provider. Every backend names the same
# claude-3-5-sonnet 2024-10-22 release, each in its own id format.
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
    APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",  # first-party id
    APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",  # Bedrock id
    APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",  # Vertex id
}
# System prompt for the Age of Empires 2 (via SteamLink) agent. sampling_loop
# appends any `system_prompt_suffix` to it. Tailor it so the model has context
# for the environment it is running in, plus any extra information that helps
# with the task at hand. It is a plain string: it interpolates nothing.
SYSTEM_PROMPT = """<SYSTEM_CAPABILITY>
* You are utilising SteamLink to play a game of Age of Empires 2.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* Always ensure that actions you take are actually executed by checking the games stats at the top left of the screen.
</IMPORTANT>"""
async def sampling_loop(
    *,
    model: str,
    provider: APIProvider,
    system_prompt_suffix: str,
    messages: list[BetaMessageParam],
    output_callback: Callable[[BetaContentBlockParam], None],
    tool_output_callback: Callable[[ToolResult, str], None],
    api_response_callback: Callable[
        [httpx.Request, httpx.Response | object | None, Exception | None], None
    ],
    api_key: str,
    only_n_most_recent_images: int | None = None,
    max_tokens: int = 4096,
    max_message_pairs: int = 6,
    request_delay: float = 15.0,
):
    """
    Agentic sampling loop for the assistant/tool interaction of computer use.

    Each turn does the following, in order:
      1. Trims the history.
      2. Optionally sets prompt-cache breakpoints and drops old screenshots.
      3. Waits `request_delay` seconds, then calls the Messages API.
      4. Types any assistant text into the game.
      5. Runs every requested tool.
    The loop ends when the model requests no tools or the API call fails.

    Args:
        model: Model name passed to the API.
        provider: Which API backend to call.
        system_prompt_suffix: Extra text appended (after a space) to SYSTEM_PROMPT.
        messages: Conversation so far; new turns are appended to it.
        output_callback: Receives a start-up status string, then every
            assistant content block.
        tool_output_callback: Receives each tool result and its tool_use id.
        api_response_callback: Receives (request, response-or-body, error)
            for every API call.
        api_key: Key for APIProvider.ANTHROPIC. Bedrock/Vertex clients are
            created without it.
        only_n_most_recent_images: If set, keep roughly this many recent
            screenshots in the history.
        max_tokens: Maximum tokens per response.
        max_message_pairs: History budget in user/assistant pairs sent per
            request. The opening message is kept in addition.
        request_delay: Seconds to wait before each API call (rate-limit guard).

    Returns:
        The message list, which is a new list once trimming has kicked in.

    Raises:
        ValueError: If `provider` is not a known APIProvider.
    """
    tool_collection = ToolCollection(
        ComputerTool(),
        BashTool(),
        EditTool(),
    )
    system = BetaTextBlockParam(
        type="text",
        text=f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}",
    )
    output_callback("Running AOE2 Player version 0.4")

    # None of this changes between turns, so build it once.
    client = _make_client(provider, api_key)
    enable_prompt_caching = provider == APIProvider.ANTHROPIC
    betas = [COMPUTER_USE_BETA_FLAG]
    image_truncation_threshold = 10
    if enable_prompt_caching:
        betas.append(PROMPT_CACHING_BETA_FLAG)
        # Is it ever worth it to bust the cache with prompt caching?
        image_truncation_threshold = 50
        system["cache_control"] = {"type": "ephemeral"}

    while True:
        messages = _trim_history(messages, max_message_pairs)
        if enable_prompt_caching:
            _inject_prompt_caching(messages)
        if only_n_most_recent_images:
            _maybe_filter_to_n_most_recent_images(
                messages,
                only_n_most_recent_images,
                min_removal_threshold=image_truncation_threshold,
            )

        # Throttle to avoid rate limiting. asyncio.sleep (not time.sleep) keeps
        # the event loop free while waiting.
        await asyncio.sleep(request_delay)

        # with_raw_response exposes the HTTP request/response so they can be
        # passed to api_response_callback for debugging.
        try:
            raw_response = client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=model,
                system=[system],
                tools=tool_collection.to_params(),
                betas=betas,
            )
        except (APIStatusError, APIResponseValidationError) as e:
            api_response_callback(e.request, e.response, e)
            return messages
        except APIError as e:
            api_response_callback(e.request, e.body, e)
            return messages

        api_response_callback(
            raw_response.http_response.request, raw_response.http_response, None
        )
        response = raw_response.parse()
        response_params = _response_to_params(response)
        messages.append({"role": "assistant", "content": response_params})

        tool_result_content: list[BetaToolResultBlockParam] = []
        for content_block in response_params:
            output_callback(content_block)
            if content_block["type"] == "text":
                await _type_into_game_chat(tool_collection, content_block["text"])
            if content_block["type"] == "tool_use":
                result = await tool_collection.run(
                    name=content_block["name"],
                    tool_input=cast(dict[str, Any], content_block["input"]),
                )
                tool_result_content.append(
                    _make_api_tool_result(result, content_block["id"])
                )
                tool_output_callback(result, content_block["id"])

        if not tool_result_content:
            return messages
        messages.append({"content": tool_result_content, "role": "user"})


def _make_client(
    provider: APIProvider, api_key: str
) -> Anthropic | AnthropicVertex | AnthropicBedrock:
    """Create the API client for `provider`; only the Anthropic client takes `api_key`."""
    if provider == APIProvider.ANTHROPIC:
        return Anthropic(api_key=api_key)
    if provider == APIProvider.VERTEX:
        return AnthropicVertex()
    if provider == APIProvider.BEDROCK:
        return AnthropicBedrock()
    # Fail fast with a clear error instead of using an unbound client later.
    raise ValueError(f"Unsupported API provider: {provider!r}")


def _trim_history(
    messages: list[BetaMessageParam], max_message_pairs: int
) -> list[BetaMessageParam]:
    """Return `messages` cut down to the opening message plus a recent tail.

    The tail holds at most `max_message_pairs * 2` messages. It is shortened
    further until it starts with an assistant turn, which keeps two invariants:
      * The history still opens with the user's original request.
      * Every retained tool_result block follows the tool_use it answers.
    A plain tail slice can start with an assistant turn, or with tool_result
    blocks whose tool_use was dropped (which the API rejects). Either way it
    loses the task message.

    A non-positive budget disables trimming. The input list is never modified.
    It is returned unchanged when no trimming is needed or no assistant turn
    exists to cut at.
    """
    limit = max_message_pairs * 2
    if limit <= 0 or len(messages) <= limit:
        return messages
    start = len(messages) - limit
    while start < len(messages) and messages[start]["role"] != "assistant":
        start += 1
    if start >= len(messages):
        return messages
    return messages[:1] + messages[start:]


async def _type_into_game_chat(
    tool_collection: ToolCollection,
    text: str,
    chunk_size: int = 100,
    delay: float = 3,
) -> None:
    """Type `text` via the computer tool, `chunk_size` characters at a time.

    Each chunk is sent as: press Return, type the chunk, wait `delay` seconds,
    press Return. Presumably the two Returns open and submit the in-game chat
    box (confirm). Tool results are discarded, and an empty `text` sends nothing.
    """
    for start in range(0, len(text), chunk_size):
        chunk = text[start:start + chunk_size]
        await tool_collection.run(
            name="computer", tool_input={"action": "key", "text": "Return"}
        )
        await tool_collection.run(
            name="computer", tool_input={"action": "type", "text": chunk}
        )
        await asyncio.sleep(delay)
        await tool_collection.run(
            name="computer", tool_input={"action": "key", "text": "Return"}
        )
def _maybe_filter_to_n_most_recent_images(
    messages: list[BetaMessageParam],
    images_to_keep: int,
    min_removal_threshold: int,
):
    """
    Drop older screenshots from tool_result blocks, in place.

    Assumes images are screenshots whose value diminishes as the conversation
    progresses. Removes all but (roughly) the final `images_to_keep`
    tool_result images, oldest first.

    Removal happens in whole multiples of `min_removal_threshold`, to break the
    implicit prompt cache less often. So when there are more than
    `images_to_keep` images, between `images_to_keep` and
    `images_to_keep + min_removal_threshold - 1` of them remain. A threshold
    below 1 is treated as 1 (no chunking) instead of raising ZeroDivisionError.

    Returns `messages` when `images_to_keep` is None; otherwise returns None.
    """
    if images_to_keep is None:
        return messages

    tool_result_blocks = cast(
        list[BetaToolResultBlockParam],
        [
            item
            for message in messages
            for item in (
                message["content"] if isinstance(message["content"], list) else []
            )
            if isinstance(item, dict) and item.get("type") == "tool_result"
        ],
    )

    total_images = sum(
        1
        for tool_result in tool_result_blocks
        for content in tool_result.get("content", [])
        if isinstance(content, dict) and content.get("type") == "image"
    )

    images_to_remove = total_images - images_to_keep
    # For better cache behavior, remove in whole chunks.
    chunk = max(min_removal_threshold, 1)
    images_to_remove -= images_to_remove % chunk
    if images_to_remove <= 0:
        # Nothing to drop; leave every content list untouched.
        return None

    for tool_result in tool_result_blocks:
        content_list = tool_result.get("content")
        if not isinstance(content_list, list):
            continue
        kept = []
        for content in content_list:
            if (
                isinstance(content, dict)
                and content.get("type") == "image"
                and images_to_remove > 0
            ):
                images_to_remove -= 1
                continue
            kept.append(content)
        tool_result["content"] = kept
    return None
def _response_to_params(
    response: BetaMessage,
) -> list[BetaTextBlockParam | BetaToolUseBlockParam]:
    """Convert a response's content blocks into request-param dicts, in order.

    Text blocks become minimal `{"type": "text", "text": ...}` dicts. Any other
    block is passed through as its `model_dump()`.
    """
    return [
        {"type": "text", "text": block.text}
        if isinstance(block, BetaTextBlock)
        else cast(BetaToolUseBlockParam, block.model_dump())
        for block in response.content
    ]
def _inject_prompt_caching(
    messages: list[BetaMessageParam],
):
    """
    Set cache breakpoints on the 3 most recent user turns, in place.

    Only user messages whose content is a list count. The last content block of
    each of the 3 newest such turns gets an ephemeral `cache_control`. The next
    older one has its `cache_control` removed, and the scan stops there. One
    cache breakpoint is left for the tools/system prompt, to be shared across
    sessions.
    """
    remaining = 3
    for message in reversed(messages):
        if message["role"] != "user":
            continue
        content = message["content"]
        if not isinstance(content, list):
            continue
        if not remaining:
            # Each loop iteration adds only one new user turn, so at most one
            # stale breakpoint needs clearing.
            content[-1].pop("cache_control", None)
            break
        remaining -= 1
        content[-1]["cache_control"] = BetaCacheControlEphemeralParam(
            {"type": "ephemeral"}
        )
def _make_api_tool_result(
    result: ToolResult, tool_use_id: str
) -> BetaToolResultBlockParam:
    """Convert an agent ToolResult to an API ToolResultBlockParam.

    On error, the content is the (system-prefixed) error string and `is_error`
    is True. Otherwise the content is a list holding an optional text block for
    `output`, then an optional base64 PNG image block.
    """
    if result.error:
        return {
            "type": "tool_result",
            "content": _maybe_prepend_system_tool_result(result, result.error),
            "tool_use_id": tool_use_id,
            "is_error": True,
        }

    blocks: list[BetaTextBlockParam | BetaImageBlockParam] = []
    if result.output:
        text = _maybe_prepend_system_tool_result(result, result.output)
        blocks.append({"type": "text", "text": text})
    if result.base64_image:
        image_source = {
            "type": "base64",
            "media_type": "image/png",
            "data": result.base64_image,
        }
        blocks.append({"type": "image", "source": image_source})

    return {
        "type": "tool_result",
        "content": blocks,
        "tool_use_id": tool_use_id,
        "is_error": False,
    }
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
    """Return `result_text`, prefixed with a `<system>` line when `result.system` is set."""
    if not result.system:
        return result_text
    return f"<system>{result.system}</system>\n{result_text}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment