OllamaModelClient for Autogen
import autogen
import ollama
import random
from ollama import ChatResponse
from autogen import AssistantAgent, ModelClient, UserProxyAgent
from types import SimpleNamespace
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

# Point Autogen at a custom client class; base_url is the address of the Ollama server.
config_list_custom = [{
    "model": "mixtral",
    "base_url": "http://192.168.80.193:11434/",
    "model_client_cls": "OllamaModelClient",
}]

llm_config = {"config_list": config_list_custom, "stream": True, "cache_seed": 42}


class OllamaModelClient(ModelClient):
    def __init__(self, config, **kwargs) -> None:
        # print(f"OllamaModelClient config: {config}")
        self.config = config
        self.client = ollama.Client(host=config["base_url"])

    def create(self, params) -> ModelClient.ModelClientResponseProtocol:
        print(f"{params}")
        # If streaming is enabled and the request has messages, iterate over the chunks of the response.
        if params.get("stream", False) and "messages" in params:
            params = params.copy()
            response_contents = [""] * params.get("n", 1)
            # finish_reasons = [""] * params.get("n", 1)
            completion_tokens = 0
            stream = self.client.chat(model=params["model"], messages=params["messages"], stream=params["stream"])
            # Set the terminal text color to green while the response streams in.
            print("\033[32m", end="")
            for chunk in stream:
                content = chunk["message"]["content"]
                if content is not None:
                    print(content, end="", flush=True)
                    response_contents[0] += content
                    completion_tokens += 1
                else:
                    # print()
                    pass
            print("\033[0m\n")
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = response_contents[0].strip()
            choice.message.function_call = None
            ret = SimpleNamespace()
            ret.model = params["model"]
            ret.choices = []
            ret.choices.append(choice)
        else:
            params = params.copy()
            params["stream"] = False
            response = self.client.chat(model=params["model"], messages=params["messages"], stream=params["stream"])
            # print(response)
            ret = SimpleNamespace()
            ret.model = params["model"]
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = response["message"]["content"].strip()
            choice.message.function_call = None
            ret.choices = []
            ret.choices.append(choice)
        return ret

    def message_retrieval(self, response):
        choices = response.choices
        return [choice.message.content for choice in choices]

    def get_usage(self, response: ModelClient.ModelClientResponseProtocol):
        return {}

    def cost(self, response) -> float:
        response.cost = 0
        return 0


assistant = RetrieveAssistantAgent(
    name="assistant",
    system_message="You are a helpful assistant.",
    llm_config=llm_config,
)

user_proxy = RetrieveUserProxyAgent(
    name="ragproxyagent",
    retrieve_config={
        "task": "qa",
        "docs_path": [
            "https://raw.githubusercontent.com/microsoft/autogen/main/README.md",
            "https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-07-14-Local-LLMs/index.md",
            # "https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-10-18-RetrieveChat/index.mdx",
            "https://raw.githubusercontent.com/ollama/ollama/main/docs/api.md",
            # "https://raw.githubusercontent.com/ollama/ollama-python/main/README.md",
        ],
    },
    code_execution_config={
        "work_dir": "coding",
        "use_docker": False,  # Please set use_docker=True if docker is available to run the generated code. Using docker is safer than running the generated code directly.
    },
)

# Register the custom client class with the agent that declared model_client_cls in its config.
assistant.register_model_client(model_client_cls=OllamaModelClient)
user_proxy.initiate_chat(recipient=assistant, problem="I want to use the Ollama api in Autogen. What do I need to do?")
This is incomplete, but it could be used as a starting point for an OllamaModelClient. I think get_usage in particular could be filled in a little better.
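For example, here is a minimal sketch of a fuller get_usage. It assumes create() is extended to copy Ollama's prompt_eval_count and eval_count fields (reported on the final chat response) onto the returned namespace, and it uses the usage keys AutoGen's custom model client protocol looks for (prompt_tokens, completion_tokens, total_tokens, cost, model):

# Sketch only. Assumes create() is modified to set these attributes from the Ollama response, e.g.:
#   ret.prompt_tokens = response.get("prompt_eval_count", 0)
#   ret.completion_tokens = response.get("eval_count", 0)
def get_usage(self, response: ModelClient.ModelClientResponseProtocol):
    prompt_tokens = getattr(response, "prompt_tokens", 0)
    completion_tokens = getattr(response, "completion_tokens", 0)
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        "cost": 0,  # a local Ollama model has no per-token cost
        "model": response.model,
    }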