# LlamaCppChatCompletionClient: an AutoGen ChatCompletionClient backed by llama-cpp-python.
# Gist by @aribornstein, created January 21, 2025.
from autogen_core import CancellationToken
from pydantic import BaseModel
from autogen_core.models import ChatCompletionClient, CreateResult, SystemMessage, UserMessage, AssistantMessage
from autogen_core.tools import Tool
from llama_cpp import Llama
from typing import List, Dict, Any, Literal, Optional, Sequence, Union, AsyncGenerator
import json


class ComponentModel(BaseModel):
    provider: str
    component_type: Optional[Literal["model", "agent", "tool", "termination", "token_provider"]] = None
    version: Optional[int] = None
    component_version: Optional[int] = None
    description: Optional[str] = None
    config: Dict[str, Any]


class LlamaCppChatCompletionClient(ChatCompletionClient):
    def __init__(
        self,
        repo_id: str,
        filename: str,
        n_gpu_layers: int = -1,
        seed: int = 1337,
        n_ctx: int = 1000,
        verbose: bool = True,
    ):
        """
        Initialize the LlamaCpp client by downloading the model from the Hugging Face Hub.
        """
        self.llm = Llama.from_pretrained(
            repo_id=repo_id,
            filename=filename,
            n_gpu_layers=n_gpu_layers,
            seed=seed,
            n_ctx=n_ctx,
            verbose=verbose,
        )
        self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}
    async def create(self, messages: List[Any], tools: Optional[List[Any]] = None, **kwargs) -> CreateResult:
        """
        Generate a response using the model, incorporating tool metadata.

        :param messages: A list of message objects to process.
        :param tools: A list of tool objects to register dynamically.
        :param kwargs: Additional arguments for the model.
        :return: A CreateResult object containing the model's response.
        """
        tools = tools or []

        # Convert LLMMessage objects to dictionaries with 'role' and 'content'
        converted_messages = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage):
                converted_messages.append({"role": "assistant", "content": msg.content})
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        # Add tool descriptions to the system message
        tool_descriptions = "\n".join(
            [f"Tool: {i + 1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools)]
        )
        few_shot_example = """
        Example tool usage:
        User: How is the weather in Paris?
        Assistant: Calling tool 'get_weather' with arguments: {"city": "Paris"}
        """
        if tool_descriptions:
            system_message = (
                "If you cannot answer the user's question directly, "
                "or if the question aligns with one of the available tools, explicitly call the corresponding tool. "
                "Provide tool arguments in JSON format when calling a tool, as in the example below.\n"
                f"{few_shot_example}\n"
                "The following tools are available:\n"
                f"{tool_descriptions}\n"
            )
            converted_messages.insert(0, {"role": "system", "content": system_message})

        response = self.llm.create_chat_completion(messages=converted_messages, stream=False)
        self._total_usage["prompt_tokens"] += response.get("usage", {}).get("prompt_tokens", 0)
        self._total_usage["completion_tokens"] += response.get("usage", {}).get("completion_tokens", 0)

        # Parse the response
        response_text = response["choices"][0]["message"]["content"]

        # Detect tool usage in the response
        tool_call = await self._detect_and_execute_tool(response_text, tools)

        # Create a CreateResult object
        create_result = CreateResult(
            content=tool_call if tool_call else response_text,
            usage=response.get("usage", {}),
            finish_reason=response["choices"][0].get("finish_reason", "unknown"),
            cached=False,  # Defaulting to `False` as the result is not cached
        )
        return create_result
    async def _detect_and_execute_tool(self, response_text: str, tools: List[Tool]) -> Optional[str]:
        """
        Detect if the model is requesting a tool and execute the tool.

        :param response_text: The raw response text from the model.
        :param tools: A list of available tools.
        :return: The result of the tool execution or None if no tool is called.
        """
        for tool in tools:
            # Case-insensitive match, since the response text is lowercased
            if tool.name.lower() in response_text.lower():
                # Extract arguments (if any) from the response
                func_args = self._extract_tool_arguments(response_text)
                args_model = tool.args_type()
                # Handle nested arguments
                if "request" in args_model.__fields__:
                    func_args = {"request": func_args}
                args_instance = args_model(**func_args)

                # Execute the tool
                try:
                    # Run the tool with extracted arguments
                    result = await tool.run(args=args_instance, cancellation_token=CancellationToken())
                    # If result is a dict, serialize it; otherwise, convert to string
                    if isinstance(result, dict):
                        return json.dumps(result)
                    elif hasattr(result, "model_dump"):  # If it's a Pydantic model
                        return json.dumps(result.model_dump())
                    else:
                        return str(result)
                except Exception as e:
                    return f"Error executing tool '{tool.name}': {e}"
        return None
    def _extract_tool_arguments(self, response_text: str) -> Dict[str, Any]:
        """
        Extract tool arguments from the response text.

        :param response_text: The raw response text.
        :return: A dictionary of extracted arguments.
        """
        try:
            # Look for a JSON object in the response; use the last closing brace
            # so nested argument objects are captured in full.
            args_start = response_text.find("{")
            args_end = response_text.rfind("}")
            if args_start != -1 and args_end != -1:
                args_str = response_text[args_start:args_end + 1]
                return json.loads(args_str)
        except json.JSONDecodeError:
            pass
        return {}
    async def create_stream(self, messages: List[Any], tools: Optional[List[Any]] = None, **kwargs) -> AsyncGenerator[str, None]:
        """
        Generate a streaming response using the model.

        :param messages: A list of messages to process.
        :param tools: A list of tool objects to register dynamically.
        :param kwargs: Additional arguments for the model.
        :return: An asynchronous generator yielding the response stream.
        """
        tools = tools or []

        # Convert LLMMessage objects to dictionaries with 'role' and 'content'
        converted_messages = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage):
                converted_messages.append({"role": "assistant", "content": msg.content})
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        # Add tool descriptions to the system message
        tool_descriptions = "\n".join(
            [f"Tool: {tool.name} - {tool.description}" for tool in tools]
        )
        if tool_descriptions:
            converted_messages.insert(
                0, {"role": "system", "content": f"The following tools are available:\n{tool_descriptions}"}
            )

        # Convert messages into a plain string prompt
        prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in converted_messages)

        # Call the model with streaming enabled
        response_generator = self.llm(prompt=prompt, stream=True)
        for token in response_generator:
            yield token["choices"][0]["text"]
    # Implement abstract methods
    def actual_usage(self) -> Dict[str, int]:
        return self._total_usage

    @property
    def capabilities(self) -> Dict[str, bool]:
        return {"chat": True, "stream": True}

    def count_tokens(self, messages: Sequence[Dict[str, Any]], **kwargs) -> int:
        # Rough estimate: whitespace-delimited word count rather than true tokenization
        return sum(len(msg["content"].split()) for msg in messages)

    @property
    def model_info(self) -> Dict[str, Any]:
        return {
            "name": "llama-cpp",
            "capabilities": {"chat": True, "stream": True},
            "context_window": self.llm.n_ctx(),  # n_ctx() returns the configured context size
            "function_calling": True,
        }

    def remaining_tokens(self, messages: Sequence[Dict[str, Any]], **kwargs) -> int:
        used_tokens = self.count_tokens(messages)
        return max(self.llm.n_ctx() - used_tokens, 0)

    def total_usage(self) -> Dict[str, int]:
        return self._total_usage
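

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist). Assumptions: the GGUF repo and
# filename below are illustrative placeholders (any chat-tuned GGUF model
# downloadable via Llama.from_pretrained works), and `get_weather` is a
# hypothetical tool wrapped with autogen_core's FunctionTool.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    from autogen_core.tools import FunctionTool

    async def get_weather(city: str) -> str:
        # Hypothetical tool body; a real implementation would query a weather API.
        return f"The weather in {city} is sunny."

    async def main() -> None:
        client = LlamaCppChatCompletionClient(
            repo_id="Qwen/Qwen2-0.5B-Instruct-GGUF",     # assumed example repo
            filename="qwen2-0_5b-instruct-q4_k_m.gguf",  # assumed example file
        )
        weather_tool = FunctionTool(get_weather, description="Get the weather for a city.")
        result = await client.create(
            messages=[UserMessage(content="How is the weather in Paris?", source="user")],
            tools=[weather_tool],
        )
        # If the model called the tool, result.content holds the tool output;
        # otherwise it holds the raw model response.
        print(result.content)

    asyncio.run(main())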