import json
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union

from autogen_core import CancellationToken
from autogen_core.models import AssistantMessage, ChatCompletionClient, CreateResult, SystemMessage, UserMessage
from autogen_core.tools import Tool
from llama_cpp import Llama
from pydantic import BaseModel


class ComponentModel(BaseModel):
    """Serializable description of a component (model, agent, tool, termination condition, or token provider)."""

    provider: str
    component_type: Optional[Literal["model", "agent", "tool", "termination", "token_provider"]] = None
    version: Optional[int] = None
    component_version: Optional[int] = None
    description: Optional[str] = None
    config: Dict[str, Any]
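

# Illustrative only: one way a ComponentModel instance for this client could look.
# The provider string and config values below are placeholder assumptions, not
# something the rest of this file consumes.
example_component_model = ComponentModel(
    provider="LlamaCppChatCompletionClient",
    component_type="model",
    version=1,
    description="llama.cpp-backed chat completion client",
    config={"repo_id": "unsloth/phi-4-GGUF", "filename": "phi-4-Q4_K_M.gguf"},
)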


class LlamaCppChatCompletionClient(ChatCompletionClient):
    def __init__(self, repo_id: str, filename: str, n_gpu_layers: int = -1, seed: int = 1337, n_ctx: int = 1000, verbose: bool = True):
        """
        Initialize the LlamaCpp client by downloading a GGUF model from the Hugging Face Hub.

        :param repo_id: Hugging Face repository ID containing the model.
        :param filename: Name of the GGUF file within the repository.
        :param n_gpu_layers: Number of layers to offload to the GPU (-1 offloads all layers).
        :param seed: Random seed for reproducible sampling.
        :param n_ctx: Context window size in tokens.
        :param verbose: Whether llama.cpp should log verbose output.
        """
        self.llm = Llama.from_pretrained(
            repo_id=repo_id,
            filename=filename,
            n_gpu_layers=n_gpu_layers,
            seed=seed,
            n_ctx=n_ctx,
            verbose=verbose,
        )
        self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}

    async def create(self, messages: List[Any], tools: List[Any] = None, **kwargs) -> CreateResult:
        """
        Generate a response using the model, incorporating tool metadata.

        :param messages: A list of message objects to process.
        :param tools: A list of tool objects to register dynamically.
        :param kwargs: Additional arguments for the model.
        :return: A CreateResult object containing the model's response.
        """
        tools = tools or []

        # Convert LLMMessage objects to dictionaries with 'role' and 'content'.
        converted_messages = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage):
                converted_messages.append({"role": "assistant", "content": msg.content})
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        # Describe the available tools in a system message so the model can request them.
        tool_descriptions = "\n".join(
            f"Tool {i + 1}: {tool.name} - {tool.description}" for i, tool in enumerate(tools)
        )
        few_shot_example = """
        Example tool usage:
        User: How is the weather in Paris?
        Assistant: Calling tool 'get_weather' with arguments: {"city": "Paris"}
        """
        if tool_descriptions:
            system_message = (
                "If you cannot answer the user's question directly, "
                "or if the question aligns with one of the available tools, explicitly call the corresponding tool. "
                "Provide tool arguments in JSON format when calling a tool, as in the example below.\n"
                f"{few_shot_example}\n"
                "The following tools are available:\n"
                f"{tool_descriptions}\n"
            )
            converted_messages.insert(0, {"role": "system", "content": system_message})

        response = self.llm.create_chat_completion(messages=converted_messages, stream=False)
        self._total_usage["prompt_tokens"] += response.get("usage", {}).get("prompt_tokens", 0)
        self._total_usage["completion_tokens"] += response.get("usage", {}).get("completion_tokens", 0)

        # Parse the response.
        response_text = response["choices"][0]["message"]["content"]

        # Detect tool usage in the response and execute the tool if one is requested.
        tool_call = await self._detect_and_execute_tool(response_text, tools)

        # Wrap the result in a CreateResult object.
        create_result = CreateResult(
            content=tool_call if tool_call else response_text,
            usage=response.get("usage", {}),
            finish_reason=response["choices"][0].get("finish_reason", "unknown"),
            cached=False,  # The result is generated fresh, not served from a cache.
        )
        return create_result

    async def _detect_and_execute_tool(self, response_text: str, tools: List[Tool]) -> Optional[str]:
        """
        Detect whether the model is requesting a tool and, if so, execute it.

        :param response_text: The raw response text from the model.
        :param tools: A list of available tools.
        :return: The result of the tool execution, or None if no tool is called.
        """
        for tool in tools:
            # Case-insensitive match on the tool name in the model's response.
            if tool.name.lower() in response_text.lower():
                # Extract arguments (if any) from the response.
                func_args = self._extract_tool_arguments(response_text)
                args_model = tool.args_type()
                # Handle tools whose argument schema nests the fields inside a 'request' object.
                if "request" in args_model.__fields__:
                    func_args = {"request": func_args}
                args_instance = args_model(**func_args)

                # Run the tool with the extracted arguments.
                try:
                    result = await tool.run(args=args_instance, cancellation_token=CancellationToken())
                    # If the result is a dict, serialize it; if it is a Pydantic model, dump it; otherwise stringify.
                    if isinstance(result, dict):
                        return json.dumps(result)
                    elif hasattr(result, "model_dump"):
                        return json.dumps(result.model_dump())
                    else:
                        return str(result)
                except Exception as e:
                    return f"Error executing tool '{tool.name}': {e}"
        return None

    def _extract_tool_arguments(self, response_text: str) -> Dict[str, Any]:
        """
        Extract tool arguments from the response text.

        :param response_text: The raw response text.
        :return: A dictionary of extracted arguments, or an empty dict if none are found.
        """
        try:
            # Look for the first JSON-like argument object in the response.
            args_start = response_text.find("{")
            args_end = response_text.find("}")
            if args_start != -1 and args_end != -1:
                args_str = response_text[args_start:args_end + 1]
                return json.loads(args_str)
        except json.JSONDecodeError:
            pass
        return {}

    async def create_stream(self, messages: List[Any], tools: List[Any] = None, **kwargs) -> AsyncGenerator[str, None]:
        """
        Generate a streaming response using the model.

        :param messages: A list of messages to process.
        :param tools: A list of tool objects to register dynamically.
        :param kwargs: Additional arguments for the model.
        :return: An asynchronous generator yielding the response stream.
        """
        tools = tools or []

        # Convert LLMMessage objects to dictionaries with 'role' and 'content'.
        converted_messages = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage):
                converted_messages.append({"role": "assistant", "content": msg.content})
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        # Add tool descriptions to the system message.
        tool_descriptions = "\n".join(f"Tool: {tool.name} - {tool.description}" for tool in tools)
        if tool_descriptions:
            converted_messages.insert(
                0, {"role": "system", "content": f"The following tools are available:\n{tool_descriptions}"}
            )

        # Flatten the chat history into a plain text prompt for token-level streaming.
        prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in converted_messages)

        # Call the model with streaming enabled and yield tokens as they arrive.
        response_generator = self.llm(prompt=prompt, stream=True)
        for token in response_generator:
            yield token["choices"][0]["text"]

    # Implementations of the remaining ChatCompletionClient abstract methods.
    def actual_usage(self) -> Dict[str, int]:
        return self._total_usage

    @property
    def capabilities(self) -> Dict[str, bool]:
        return {"chat": True, "stream": True}

    def count_tokens(self, messages: Sequence[Dict[str, Any]], **kwargs) -> int:
        # Rough whitespace-based estimate; llama.cpp's own tokenizer would be more accurate.
        return sum(len(msg["content"].split()) for msg in messages)

    @property
    def model_info(self) -> Dict[str, Any]:
        return {
            "name": "llama-cpp",
            "capabilities": {"chat": True, "stream": True},
            "context_window": self.llm.n_ctx(),  # n_ctx() is a method on llama_cpp.Llama
            "function_calling": True,
        }

    def remaining_tokens(self, messages: Sequence[Dict[str, Any]], **kwargs) -> int:
        used_tokens = self.count_tokens(messages)
        return max(self.llm.n_ctx() - used_tokens, 0)

    def total_usage(self) -> Dict[str, int]:
        return self._total_usage
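

# Minimal usage sketch: instantiate the client and call create() with a FunctionTool.
# The repo_id, filename, and the get_weather tool below are illustrative assumptions;
# substitute any GGUF chat model and autogen_core tool you actually have available.
if __name__ == "__main__":
    import asyncio

    from autogen_core.tools import FunctionTool

    async def get_weather(city: str) -> str:
        """Hypothetical tool: return a canned weather report for a city."""
        return f"The weather in {city} is sunny."

    async def main() -> None:
        client = LlamaCppChatCompletionClient(
            repo_id="unsloth/phi-4-GGUF",  # assumed model repository
            filename="phi-4-Q4_K_M.gguf",  # assumed GGUF file name
            n_ctx=4096,
        )
        weather_tool = FunctionTool(get_weather, description="Get the weather for a city.")
        result = await client.create(
            messages=[UserMessage(content="How is the weather in Paris?", source="user")],
            tools=[weather_tool],
        )
        print(result.content)

    asyncio.run(main())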