Last active
March 1, 2025 00:38
-
-
Save mnadel/65b3de4888dc9355d7fbc5af36bbc13d to your computer and use it in GitHub Desktop.
OpenAI Assistants API w/ Function Calls
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""Ask an OpenAI assistant a question from the command line.

The prompt is the first positional argument, optionally followed by whatever
is piped on stdin. While the assistant run is in progress, the script services
the assistant's function calls by fetching the requested URL and returning the
page's visible text. The assistant's reply is printed to stdout.
"""
## requirements:
# beautifulsoup4
# requests
# openai>=1.0    (uses the v1 `OpenAI` client)
# python-dotenv  (the canonical PyPI distribution that provides `import dotenv`)
import argparse
import json
import logging
import os
import signal
import sys
import time

import requests
from bs4 import BeautifulSoup
from bs4.element import Comment
from dotenv import load_dotenv
from openai import OpenAI

# ID of the pre-created assistant to run. The script never creates one itself.
# The assistant is expected to expose a single `fetch_url` function tool.
# Replace this placeholder with your own assistant's ID.
ASSISTANT_ID = "asst_abcdefghijkl"
# Fetch the url and strip the HTML markup so the text handed back to the LLM
# stays small enough for its context window.
def text_from_url(url, timeout=None):
    """Fetch ``url`` and return the human-visible text of the page.

    Text whose parent is a <style>, <script>, <head>, <title> or <meta>
    element is dropped, as are HTML comments and top-level document nodes.
    The remaining text fragments are stripped of surrounding whitespace,
    empty fragments are discarded, and the rest are joined with single spaces.

    Args:
        url: address to fetch with an HTTP GET.
        timeout: request timeout in seconds. Defaults to the ``--fetch-timeout``
            command-line option (the module-level ``args.fetch_timeout``).

    Returns:
        The visible text of the page as a single string.

    Exits:
        Terminates the process with status 2 if the request cannot be made
        (timeout, DNS or connection error, invalid URL) or the server does
        not answer with HTTP 200.
    """
    logging.info(f"fetching url={url}")

    if timeout is None:
        # `args` is the argparse namespace parsed at module level on startup.
        timeout = args.fetch_timeout

    def is_tag_visible(element):
        # Text living directly under non-rendered elements is not page content.
        if element.parent.name in ("style", "script", "head", "title", "meta", "[document]"):
            return False
        if isinstance(element, Comment):
            return False
        return True

    # NOTE(review): `url` comes from model output and is fetched as-is.
    # Consider restricting schemes/hosts if this runs anywhere with access to
    # internal services.
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        # Without this, timeouts, refused connections and schemeless URLs
        # escaped as tracebacks instead of the script's fetch-failure exit code.
        logging.error(f"failed to fetch {url}, error={err}")
        sys.exit(2)

    if response.status_code != 200:
        logging.error(f"failed to fetch {url}, code={response.status_code}, response={response.text}")
        sys.exit(2)

    soup = BeautifulSoup(response.text, "html.parser")
    # find_all is the current spelling; findAll is a deprecated alias.
    texts = soup.find_all(string=True)
    stripped = (t.strip() for t in texts if is_tag_visible(t))
    # Skip whitespace-only fragments so the result is not padded with long
    # runs of spaces that waste context tokens.
    text = " ".join(s for s in stripped if s)
    logging.debug(f"fetched url={url}, text={text}")
    return text
# NOTE(review): `logger` is currently unused; the script logs through the
# root `logging` module functions instead.
logger = logging.getLogger(__name__)

# Command-line interface.
parser = argparse.ArgumentParser(description="query openai")
parser.add_argument("prompt", help="the llm prompt")
parser.add_argument("-t", "--timeout", type=int, default=60, help="timeout in seconds for assistant run")
parser.add_argument("-f", "--fetch-timeout", type=int, default=5, help="timeout in seconds for url fetch")
parser.add_argument("-p", "--poll", type=float, default=0.75, help="poll interval in seconds")
parser.add_argument("-m", "--max-messages", type=int, default=10, help="number of messages to retrieve from assistant run")
parser.add_argument("-v", "--verbose", action="store_true", default=False, help="verbose output")
parser.add_argument("-d", "--debug", action="store_true", default=False, help="debug output")
args = parser.parse_args()

# argparse guarantees the positional argument is present, but not that it is
# non-empty (e.g. `script ""`).
if not args.prompt:
    parser.exit(1, "prompt is required\n")
# Log to stderr so stdout carries only the assistant's answer.
# --debug wins over --verbose; by default only warnings and errors are shown.
logging.basicConfig(
    level=logging.DEBUG if args.debug else logging.INFO if args.verbose else logging.WARNING,
    format="[%(asctime)s] %(levelname)s * %(message)s",
    handlers=[logging.StreamHandler(sys.stderr)],
)

# Load a .env file (where OPENAI_API_KEY might be) into the environment.
load_dotenv()

# Restore the default actions for Ctrl-C and broken pipes. The process then
# terminates quietly (e.g. when piped into `head`) instead of dumping
# KeyboardInterrupt / BrokenPipeError tracebacks. SIGPIPE does not exist on
# Windows, hence the guard.
signal.signal(signal.SIGINT, signal.SIG_DFL)
if hasattr(signal, "SIGPIPE"):
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)

api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    # Fail with a clear message rather than a bare KeyError traceback.
    logging.error("OPENAI_API_KEY is not set (export it or add it to a .env file)")
    sys.exit(1)
client = OpenAI(api_key=api_key)
# The command-line prompt always comes first.
prompts = [args.prompt]

# If stdin is not a tty (i.e. input is piped or redirected), add its contents
# as a second prompt.
if not sys.stdin.isatty():
    stdin = sys.stdin.read().strip()
    # An empty pipe (e.g. `< /dev/null`) carries no information and would
    # otherwise become an empty message, which the Messages API may reject.
    if stdin:
        prompts.append(stdin)
# Every conversation with an assistant lives in a thread, so open a new one.
thread = client.beta.threads.create()

# Post the prompts to the thread, in order, each as a user message.
for user_text in prompts:
    client.beta.threads.messages.create(
        content=user_text,
        role="user",
        thread_id=thread.id,
    )

# Start a run of the configured assistant over the thread. The run proceeds
# server-side, and its status has to be polled until it finishes.
run = client.beta.threads.runs.create(assistant_id=ASSISTANT_ID, thread_id=thread.id)
# Exit status for each non-success terminal run state. Still running when
# --timeout elapses exits with 3, and a failed URL fetch exits with 2.
RUN_FAILURE_EXIT_CODES = {
    "failed": 4,
    "expired": 5,
    "cancelled": 6,
    "incomplete": 6,
}


def _run_tool_call(call):
    """Execute one function tool call requested by the assistant.

    Every call is assumed to be ``fetch_url`` with JSON arguments of the form
    ``{"url": ...}``; the assistant behind ASSISTANT_ID is configured with
    that single function.

    Args:
        call: one entry of ``run.required_action.submit_tool_outputs.tool_calls``.

    Returns:
        The text to submit as the call's output. If the arguments are
        malformed, an error message is returned instead, so the assistant
        can react rather than the script crashing.
    """
    try:
        url = json.loads(call.function.arguments)["url"]
    except (json.JSONDecodeError, KeyError, TypeError) as err:
        logging.warning(f"invalid function arguments={call.function.arguments}: {err}")
        return f"error: invalid arguments for {call.function.name}: {err}"
    # text_from_url exits the process (status 2) if the page cannot be fetched.
    return text_from_url(url)


# time.monotonic() is immune to wall-clock adjustments while we wait.
start = time.monotonic()

# Poll until the run reaches a terminal state, servicing tool calls on the way.
while True:
    run = client.beta.threads.runs.retrieve(
        thread_id=thread.id,
        run_id=run.id,
    )
    logging.debug(f"assistant status={run.status}")

    if run.status == "completed":
        break

    # Check terminal failures before the timeout so the real reason is reported.
    if run.status in RUN_FAILURE_EXIT_CODES:
        # `run` is an SDK (pydantic) model, not a dict: json.dumps(run) raises
        # TypeError, so serialize it with the model's own JSON encoder.
        logging.error(f"assistant {run.status} run={run.model_dump_json()}")
        sys.exit(RUN_FAILURE_EXIT_CODES[run.status])

    if time.monotonic() - start > args.timeout:
        logging.error(f"timed out waiting for assistant run={run.model_dump_json()}")
        sys.exit(3)

    if run.status == "requires_action":
        logging.info("assistant is ready for function call")
        # The model may request several calls in one step (parallel function
        # calling). The API only accepts the submission once every call has an
        # output, so answer all of them, not just the first.
        tool_outputs = [
            {"tool_call_id": call.id, "output": _run_tool_call(call)}
            for call in run.required_action.submit_tool_outputs.tool_calls
        ]
        run = client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread.id,
            run_id=run.id,
            tool_outputs=tool_outputs,
        )

    # Any other (non-terminal) state just needs more time.
    time.sleep(args.poll)
# The run completed: fetch the thread's most recent messages, newest first.
messages = client.beta.threads.messages.list(
    thread_id=thread.id,
    order="desc",
    limit=args.max_messages,
)

# Collect the text parts of the assistant's messages; other content types are
# skipped. Only the first page (`.data`) is read so that --max-messages is a
# real cap. Iterating the page object itself would auto-paginate through the
# entire thread.
assistant_messages = []
for msg in messages.data:
    if msg.role != "assistant":
        continue
    assistant_messages.extend(
        content.text.value for content in msg.content if content.type == "text"
    )

# Print oldest first so the output reads in conversation order.
for message in reversed(assistant_messages):
    print(message)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment