Last active
March 8, 2025 16:57
-
-
Save BeautyyuYanli/9e66513e665ebaf3a4658e61fb9c04ef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is a workaround until https://github.com/pydantic/pydantic-ai/issues/224 is fixed.
from pydantic_ai.models.openai import ( | |
OpenAIModel, | |
OpenAIModelSettings, | |
ModelRequestParameters, | |
) | |
from itertools import chain | |
from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream | |
from openai.types import ChatModel, chat | |
from openai.types.chat import ChatCompletionChunk | |
from typing import Literal | |
class ProxyOpenAIModel(OpenAIModel):
    """``OpenAIModel`` variant whose requests never ask for ``tool_choice="required"``.

    Workaround for https://github.com/pydantic/pydantic-ai/issues/224: the
    chat-completion request is rebuilt here so that ``tool_choice`` is
    ``"auto"`` whenever tools are present and omitted (``NOT_GIVEN``) otherwise.
    """

    async def _completions_create(  # type: ignore
        self,
        messages: "list[ModelMessage]",  # type: ignore
        stream: bool,
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> "chat.ChatCompletion | AsyncStream[ChatCompletionChunk]":
        """Send a single chat-completion request built from ``messages``.

        Args:
            messages: Conversation history; each entry is converted with
                ``self._map_message`` and the results are flattened in order.
            stream: Whether to request a streamed response. When true,
                ``stream_options={"include_usage": True}`` is also sent.
            model_settings: Per-request settings; every key that is absent is
                passed to the client as ``NOT_GIVEN``.
            model_request_parameters: Used only to look up the available tools
                via ``self._get_tools``.

        Returns:
            The awaited result of ``self.client.chat.completions.create``,
            annotated as a ``ChatCompletion`` or an ``AsyncStream`` of
            ``ChatCompletionChunk`` (presumably the stream when ``stream`` is
            true).
        """
        tools = self._get_tools(model_request_parameters)
        # Only "auto" is ever sent; "none"/"required" are never produced here.
        tool_choice = "auto" if tools else NOT_GIVEN

        flattened = list(
            chain.from_iterable(self._map_message(msg) for msg in messages)
        )

        # Request argument name -> key looked up in ``model_settings``.
        setting_keys = {
            "parallel_tool_calls": "parallel_tool_calls",
            "max_tokens": "max_tokens",
            "temperature": "temperature",
            "top_p": "top_p",
            "timeout": "timeout",
            "seed": "seed",
            "presence_penalty": "presence_penalty",
            "frequency_penalty": "frequency_penalty",
            "logit_bias": "logit_bias",
            # the settings key carries an ``openai_`` prefix
            "reasoning_effort": "openai_reasoning_effort",
        }
        optional = {
            arg: model_settings.get(key, NOT_GIVEN)
            for arg, key in setting_keys.items()
        }

        return await self.client.chat.completions.create(
            model=self._model_name,
            messages=flattened,
            n=1,
            tools=tools if tools else NOT_GIVEN,
            tool_choice=tool_choice,
            stream=stream,
            stream_options={"include_usage": True} if stream else NOT_GIVEN,
            **optional,
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment