Created
October 5, 2024 00:54
-
-
Save lmolkova/be1c8b4eeb8e4c176687b074a51c0501 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import os | |
import time | |
from typing import override | |
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor | |
#HTTPXClientInstrumentor().instrument() | |
from openai import AssistantEventHandler | |
from openai.types.beta.threads.runs import RunStep | |
from opentelemetry import trace | |
from opentelemetry.trace import get_tracer, StatusCode | |
from opentelemetry.util.types import Attributes | |
from opentelemetry.sdk.trace import TracerProvider | |
from opentelemetry.sdk.resources import Resource | |
from opentelemetry.sdk.trace.export import SimpleSpanProcessor | |
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter | |
#from opentelemetry.instrumentation.openai import OpenAIInstrumentor | |
#OpenAIInstrumentor().instrument() | |
#from opentelemetry.instrumentation.requests import RequestsInstrumentor | |
#from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor | |
from chat.settings import MODEL, OPENAI_CLIENT as openai | |
#from azure.monitor.opentelemetry.exporter import AzureMonitorLogExporter, AzureMonitorTraceExporter | |
from events import MyEventLoggerProvider | |
def configure_tracing() -> TracerProvider:
    """Build a tracer provider exporting over OTLP and register it globally."""
    resource = Resource({"service.name": "assistant"})
    tracer_provider = TracerProvider(resource=resource)
    exporter = OTLPSpanExporter()
    tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
    trace.set_tracer_provider(tracer_provider)
    return tracer_provider
def wait_on_run(run, thread):
    """Poll a run every 0.5s until it leaves the queued/in_progress states."""
    while run.status in ("queued", "in_progress"):
        run = retrieve_run(run.thread_id, run.id)
        time.sleep(0.5)
    return run
def show_json(obj):
    """Print *obj* (a pydantic model) as a plain Python dict."""
    as_dict = json.loads(obj.model_dump_json())
    print(as_dict)
# Module-level tracing setup: register the global provider once and grab a
# tracer for this module; every traced function below uses this tracer.
tracer_provider = configure_tracing()
tracer = tracer_provider.get_tracer(__name__)
#logger = logger_provider.get_logger(__name__)
@tracer.start_as_current_span("create_assistant " + MODEL)
def create_assistant(name, instructions, model, temperature=None, top_p=None):
    """Create an OpenAI assistant, recording request params and the new id.

    temperature/top_p are optional; they are recorded on the span only when
    supplied. The previous truthiness checks (`if temperature:`) silently
    dropped the legitimate value 0 — `is not None` fixes that.
    """
    span = trace.get_current_span()
    set_common_attributes(span, "create_assistant")
    span.set_attribute("gen_ai.request.model", model)
    if temperature is not None:  # 0 is a valid temperature
        span.set_attribute("gen_ai.request.temperature", temperature)
    if top_p is not None:  # 0 is a valid top_p
        span.set_attribute("gen_ai.request.top_p", top_p)
    span.set_attribute("gen_ai.assistant.name", name)
    span.add_event("gen_ai.system.message", {"gen_ai.event.content": json.dumps({"content": instructions})})  # todo tool definitions
    assistant = do_with_error_reporting(span, openai.beta.assistants.create, name=name, instructions=instructions, model=model, temperature=temperature, top_p=top_p)
    span.set_attribute("gen_ai.assistant.id", assistant.id)
    return assistant
@tracer.start_as_current_span("create_thread")
def create_thread():
    """Create a fresh OpenAI thread and record its id on the current span."""
    current = trace.get_current_span()
    set_common_attributes(current, "create_thread")
    new_thread = do_with_error_reporting(current, openai.beta.threads.create)
    current.set_attribute("gen_ai.thread.id", new_thread.id)
    return new_thread
@tracer.start_as_current_span("create_message")
def create_message(thread_id, role, content):
    """Append a message to a thread, emitting it as a gen_ai span event."""
    current = trace.get_current_span()
    set_common_attributes(current, "create_message", thread_id=thread_id)
    current.add_event(f"gen_ai.{role}.message", {"gen_ai.event.content": json.dumps({"content": content})})  # todo multi-modality and list of messages
    created = do_with_error_reporting(current, openai.beta.threads.messages.create, thread_id=thread_id, role=role, content=content)
    current.set_attribute("gen_ai.thread.id", thread_id)
    current.set_attribute("gen_ai.message.id", created.id)
    return created
@tracer.start_as_current_span("create_run")
def create_run(thread_id, assistant_id):
    """Start a run on *thread_id* with *assistant_id*; no polling is done."""
    current = trace.get_current_span()
    set_common_attributes(current, "create_run", assistant_id=assistant_id, thread_id=thread_id)
    started = do_with_error_reporting(current, openai.beta.threads.runs.create, thread_id=thread_id, assistant_id=assistant_id)
    set_run_attributes(current, started)
    current.add_event("gen_ai.system.message", {"gen_ai.event.content": json.dumps({"content": started.instructions})})  # todo tool definitions
    return started
# Span name fixed: was "create_and_pool_run" (typo), which disagreed with the
# gen_ai.operation.name attribute set below.
@tracer.start_as_current_span("create_and_poll_run")
def create_and_poll_run(thread_id, assistant_id):
    """Start a run and block until it reaches a terminal state.

    Records run attributes and the system instructions event on the span once
    polling completes.
    """
    span = trace.get_current_span()
    set_common_attributes(span, "create_and_poll_run", assistant_id=assistant_id, thread_id=thread_id)
    run = do_with_error_reporting(span, openai.beta.threads.runs.create_and_poll, thread_id=thread_id, assistant_id=assistant_id)
    set_run_attributes(span, run)
    span.add_event("gen_ai.system.message", {"gen_ai.event.content": json.dumps({"content": run.instructions})})  # todo tool definitions
    return run
@tracer.start_as_current_span("submit_tool_output")
def submit_tool_outputs(thread_id, run_id, tool_outputs, run_span):
    """Submit tool outputs for a run and stream the continuation.

    Each output is recorded as a gen_ai.tool.message event; streamed events
    are routed to an EventHandler bound to *run_span* so assistant output
    lands on the run's span rather than this one. Streaming errors are
    recorded on this span (best-effort, not re-raised — matches original
    behavior).
    """
    span = trace.get_current_span()
    set_common_attributes(span, "submit_tool_output", thread_id=thread_id, run_id=run_id)
    for tool_output in tool_outputs:
        span.add_event("gen_ai.tool.message", {"gen_ai.event.content": json.dumps({"content": tool_output})})
    try:
        with openai.beta.threads.runs.submit_tool_outputs_stream(
                run_id=run_id,
                thread_id=thread_id,
                tool_outputs=tool_outputs,
                event_handler=EventHandler(run_span)) as stream:
            stream.until_done()
    except Exception as e:
        # __qualname__ lives on the exception class, not the instance
        # (e.__qualname__ raised AttributeError in the original).
        span.set_attribute("error.type", type(e).__qualname__)
        # OTel spans expose no writable .status; use the set_status API.
        span.set_status(StatusCode.ERROR)
        span.record_exception(e)
class EventHandler(AssistantEventHandler):
    """Streaming handler that mirrors assistant activity onto an OTel span.

    The span is owned by the caller (typically the run span); this handler
    only adds events/attributes to it and never starts or ends it.
    """

    def __init__(self, span):
        super().__init__()
        self.span = span

    @override
    def on_run_completed(self, run) -> None:
        print("RUN COMPLETED\n")

    @override
    def on_run_step_completed(self, run_step: RunStep) -> None:
        print("RUN STEP COMPLETED\n")

    @override
    def on_end(self) -> None:
        pass

    @override
    def on_event(self, event) -> None:
        # Tool calls surface as a requires_action event; everything else is
        # handled by the more specific callbacks on this class.
        if event.event == 'thread.run.requires_action':
            run_id = event.data.id  # Retrieve the run ID from the event data
            self.handle_requires_action(event.data, run_id)

    @override
    def on_run_step_done(self, run_step: RunStep) -> None:
        print(f"RUN STEP DONE {json.loads(run_step.model_dump_json())}\n")

    @override
    def on_tool_call_done(self, tool_call) -> None:
        # Record the completed tool call request as an assistant message event.
        self.span.add_event("gen_ai.assistant.message", {"gen_ai.event.content":
            json.dumps({"tool_calls": [json.loads(tool_call.model_dump_json())]})}
        )

    @override
    def on_exception(self, exception: Exception) -> None:
        # __qualname__ lives on the exception class, not the instance
        # (exception.__qualname__ raised AttributeError in the original).
        self.span.set_attribute("error.type", type(exception).__qualname__)
        # OTel spans expose no writable .status; use the set_status API.
        self.span.set_status(StatusCode.ERROR)
        self.span.record_exception(exception)

    @override
    def on_timeout(self) -> None:
        self.span.set_attribute("error.type", "timeout")
        self.span.set_status(StatusCode.ERROR)

    @override
    def on_message_done(self, message) -> None:
        report_message_event(self.span, message)

    @override
    def on_image_file_done(self, image_file) -> None:
        print("ON IMAGE DONE\n")

    def handle_requires_action(self, data, run_id):
        """Run the requested (stubbed) tools and submit all outputs at once."""
        tool_outputs = []
        for tool in data.required_action.submit_tool_outputs.tool_calls:
            if tool.function.name == "get_current_temperature":
                with tracer.start_as_current_span("get_current_temperature"):
                    tool_outputs.append({"tool_call_id": tool.id, "output": "57"})
            elif tool.function.name == "get_rain_probability":
                with tracer.start_as_current_span("get_rain_probability"):
                    tool_outputs.append({"tool_call_id": tool.id, "output": "0.06"})
        # Submit all tool_outputs at the same time
        submit_tool_outputs(self.current_run.thread_id, run_id, tool_outputs, self.span)
@tracer.start_as_current_span("stream_run")
def stream_run(assistant_id, thread_id):
    """Stream a run to completion and return the finished run object.

    Raises: any streaming error is recorded on the span and re-raised.
    (Previously the error was swallowed and None returned, which made every
    caller crash later on `run.thread_id`.)
    """
    span = trace.get_current_span()
    set_common_attributes(span, "stream_run", assistant_id=assistant_id, thread_id=thread_id)
    try:
        with openai.beta.threads.runs.stream(assistant_id=assistant_id,
                                             thread_id=thread_id,
                                             event_handler=EventHandler(span)) as stream:
            stream.until_done()
        set_run_attributes(span, stream.current_run)
        span.add_event("gen_ai.system.message", {"gen_ai.event.content": json.dumps({"content": stream.current_run.instructions})})  # todo tool definitions
        return stream.current_run
    except Exception as e:
        # __qualname__ lives on the exception class, not the instance.
        span.set_attribute("error.type", type(e).__qualname__)
        # OTel spans expose no writable .status; use the set_status API.
        span.set_status(StatusCode.ERROR)
        span.record_exception(e)
        raise
def set_run_attributes(span, run):
    """Record run id/model/status, sampling params, and token usage on *span*.

    temperature/top_p can be None on a run object; None is not a valid OTel
    attribute value, so they are only recorded when present.
    """
    span.set_attribute("gen_ai.run.id", run.id)
    span.set_attribute("gen_ai.request.model", run.model)
    span.set_attribute("gen_ai.run.status", run.status)
    if run.temperature is not None:
        span.set_attribute("gen_ai.request.temperature", run.temperature)
    if run.top_p is not None:
        span.set_attribute("gen_ai.request.top_p", run.top_p)
    if run.usage:
        span.set_attribute("gen_ai.response.input_tokens", run.usage.prompt_tokens)
        span.set_attribute("gen_ai.response.output_tokens", run.usage.completion_tokens)
def set_common_attributes(span, operation_name, assistant_id=None, thread_id=None, run_id=None):
    """Set the gen_ai attributes shared by every operation span."""
    if assistant_id:
        span.set_attribute("gen_ai.assistant.id", assistant_id)
    if thread_id:
        span.set_attribute("gen_ai.thread.id", thread_id)
    if run_id:
        # Fixed: run_id was written to "gen_ai.thread.id" (copy-paste bug),
        # clobbering the thread id whenever both were supplied.
        span.set_attribute("gen_ai.run.id", run_id)
    span.set_attribute("gen_ai.operation.name", operation_name)
    span.set_attribute("gen_ai.system", "openai")
    span.set_attribute("server.address", "openai-shared.openai.azure.com")
@tracer.start_as_current_span("retrieve_run")
def retrieve_run(thread_id, run_id):
    """Fetch the current state of a run and record it on the span."""
    current = trace.get_current_span()
    set_common_attributes(current, "retrieve_run", thread_id=thread_id, run_id=run_id)
    fetched = do_with_error_reporting(current, openai.beta.threads.runs.retrieve, thread_id=thread_id, run_id=run_id)
    set_run_attributes(current, fetched)
    return fetched
def do_with_error_reporting(span, func, *args, **kwargs):
    """Invoke *func*, recording any exception on *span* before re-raising."""
    try:
        return func(*args, **kwargs)
    except Exception as e:
        # __qualname__ lives on the exception class, not the instance
        # (e.__qualname__ raised AttributeError in the original).
        span.set_attribute("error.type", type(e).__qualname__)
        # OTel spans expose no writable .status; use the set_status API.
        span.set_status(StatusCode.ERROR)
        span.record_exception(e)
        raise  # bare raise preserves the original traceback
@tracer.start_as_current_span("list_messages")
def list_messages(thread_id):
    """List all messages on a thread, emitting each one as a span event."""
    current = trace.get_current_span()
    set_common_attributes(current, "list_messages", thread_id=thread_id)
    page = do_with_error_reporting(current, openai.beta.threads.messages.list, thread_id=thread_id)
    for msg in page:
        report_message_event(current, msg)
    return page
def report_message_event(span, message):
    """Emit a thread message as a gen_ai.<role>.message span event.

    Only the first content part's text is captured. Messages with no content
    parts no longer raise IndexError; their content is recorded as null.
    NOTE(review): a first part without a `.text` attribute (e.g. image-only)
    would still fail — confirm whether such messages can reach this path.
    """
    first_text = message.content[0].text.value if message.content else None
    message_attributes = {
        "gen_ai.message.id": message.id,
        "gen_ai.message.status": message.status,
        "gen_ai.assistant.id": message.assistant_id,
        "gen_ai.thread.id": message.thread_id,
        "gen_ai.run.id": message.run_id,
        "gen_ai.event.content": json.dumps({"content": first_text}),
    }
    span.add_event(f"gen_ai.{message.role}.message", message_attributes)
def simple_stream():
    """Demo: minimal math-tutor assistant, run via streaming."""
    tutor = create_assistant("Math Tutor", "You are a personal math tutor. Answer questions briefly, in a sentence or less.", MODEL)
    math_thread = create_thread()
    create_message(math_thread.id, "user", "I need to solve the equation `3x + 11 = 14`. Can you help me?")
    return stream_run(tutor.id, math_thread.id)
def create_and_poll():
    """Demo: math-tutor assistant, run via create-and-poll instead of streaming."""
    assistant = create_assistant("Math Tutor", "You are a personal math tutor. Answer questions briefly, in a sentence or less.", MODEL)
    thread = create_thread()
    # The message only needs to exist on the thread; the return value was
    # previously bound to an unused local.
    create_message(thread.id, "user", "I need to solve the equation `3x + 11 = 14`. Can you help me?")
    return create_and_poll_run(thread.id, assistant.id)
def with_tool_calling():
    """Demo: weather assistant with two stubbed tools, streamed to completion."""
    temperature_tool = {
        "type": "function",
        "function": {
            "name": "get_current_temperature",
            "description": "Get the current temperature for a specific location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g., San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["Celsius", "Fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the user's location."
                    }
                },
                "required": ["location", "unit"]
            }
        }
    }
    rain_tool = {
        "type": "function",
        "function": {
            "name": "get_rain_probability",
            "description": "Get the probability of rain for a specific location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g., San Francisco, CA"
                    }
                },
                "required": ["location"]
            }
        }
    }
    assistant = openai.beta.assistants.create(
        instructions="You are a weather bot. Use the provided functions to answer questions.",
        model=MODEL,
        tools=[temperature_tool, rain_tool],
    )
    thread = create_thread()
    create_message(thread.id, "user", "What's the weather in San Francisco today and the likelihood it'll rain?")
    return stream_run(assistant.id, thread.id)
def main():
    """Run the tool-calling demo and list the resulting thread messages."""
    finished_run = with_tool_calling()
    messages = list_messages(finished_run.thread_id)
    #show_json(messages)
# Guard the entry point so importing this module doesn't fire off API calls.
if __name__ == "__main__":
    with tracer.start_as_current_span("main"):
        main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment