-
-
Save jimbrig/2ce3caa272b6c819f4784b9d4a1dc021 to your computer and use it in GitHub Desktop.
LLM Factory with Instructor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Dict, List, Type | |
import instructor | |
from anthropic import Anthropic | |
from config.settings import get_settings | |
from openai import OpenAI | |
from pydantic import BaseModel, Field | |
class LLMFactory:
    """Wrap a provider SDK client with Instructor for structured completions.

    The provider name selects both the settings sub-object (an attribute of
    ``get_settings()``) and the client constructor used in
    ``_initialize_client``. Supported providers: ``"openai"``,
    ``"anthropic"`` and ``"llama"``.
    """

    def __init__(self, provider: str):
        """Create a factory bound to a single provider.

        Args:
            provider: One of ``"openai"``, ``"anthropic"`` or ``"llama"``.

        Raises:
            ValueError: If ``provider`` is not supported.
        """
        self.provider = provider
        # Use a None default so an unknown provider is reported by
        # _initialize_client as ValueError, rather than escaping here as an
        # AttributeError from the settings object.
        self.settings = getattr(get_settings(), provider, None)
        self.client = self._initialize_client()

    def _initialize_client(self) -> Any:
        """Return the Instructor-patched client for ``self.provider``.

        Raises:
            ValueError: If ``self.provider`` has no registered initializer.
        """
        client_initializers = {
            "openai": lambda s: instructor.from_openai(OpenAI(api_key=s.api_key)),
            "anthropic": lambda s: instructor.from_anthropic(
                Anthropic(api_key=s.api_key)
            ),
            # llama is reached through the OpenAI client pointed at s.base_url,
            # using Instructor's JSON mode instead of its default mode
            # (presumably because the local server lacks tool calling - confirm).
            "llama": lambda s: instructor.from_openai(
                OpenAI(base_url=s.base_url, api_key=s.api_key),
                mode=instructor.Mode.JSON,
            ),
        }
        initializer = client_initializers.get(self.provider)
        if initializer is None:
            raise ValueError(f"Unsupported LLM provider: {self.provider}")
        return initializer(self.settings)

    def create_completion(
        self, response_model: Type[BaseModel], messages: List[Dict[str, str]], **kwargs
    ) -> Any:
        """Request a chat completion parsed into ``response_model``.

        Args:
            response_model: Pydantic model Instructor should parse the reply into.
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            **kwargs: ``model``, ``temperature``, ``max_retries`` and
                ``max_tokens`` override the provider defaults from settings;
                any other keyword is forwarded unchanged to
                ``chat.completions.create``.

        Returns:
            Whatever the Instructor client's ``create`` returns; the
            ``__main__`` demo expects an instance of ``response_model``.
        """
        # kwargs is a fresh dict for this call, so popping from it is safe.
        completion_params = {
            "model": kwargs.pop("model", self.settings.default_model),
            "temperature": kwargs.pop("temperature", self.settings.temperature),
            "max_retries": kwargs.pop("max_retries", self.settings.max_retries),
            "max_tokens": kwargs.pop("max_tokens", self.settings.max_tokens),
            "response_model": response_model,
            "messages": messages,
            # Forward remaining options (e.g. top_p, stop) instead of dropping them.
            **kwargs,
        }
        return self.client.chat.completions.create(**completion_params)
if __name__ == "__main__":
    # Smoke test: send a reasoning question to the OpenAI provider and check
    # that Instructor parses the reply into the structured model below.

    # Field descriptions double as instructions in the schema sent to the model.
    class CompletionModel(BaseModel):
        response: str = Field(description="Your response to the user.")
        reasoning: str = Field(description="Explain your reasoning for the response.")

    user_prompt = "If it takes 2 hours to dry 1 shirt out in the sun, how long will it take to dry 5 shirts?"
    conversation = [
        dict(role="system", content="You are a helpful assistant."),
        dict(role="user", content=user_prompt),
    ]

    openai_factory = LLMFactory("openai")
    completion = openai_factory.create_completion(
        response_model=CompletionModel, messages=conversation
    )
    assert isinstance(completion, CompletionModel)

    print(f"Response: {completion.response}\n")
    print(f"Reasoning: {completion.reasoning}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
from functools import lru_cache
from typing import Optional

from dotenv import load_dotenv
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
# Load variables from a .env file into os.environ (python-dotenv) at import
# time, so they are visible before the settings classes below read the
# environment.
load_dotenv()
class LLMProviderSettings(BaseSettings):
    """Generation defaults shared by every provider-specific settings class.

    LLMFactory.create_completion falls back to these values when the caller
    does not pass the matching keyword argument. Being a pydantic-settings
    ``BaseSettings``, each field can also be overridden from an environment
    variable (per pydantic-settings: matched by field name, plus any
    ``env_prefix`` a subclass configures).
    """

    temperature: float = 0.0  # default sampling temperature
    max_tokens: Optional[int] = None  # None = no cap set here; client/provider decides how None is handled
    max_retries: int = 3  # forwarded as max_retries to the Instructor client's create()
class OpenAISettings(LLMProviderSettings):
    """Credentials and default model for the ``"openai"`` provider."""

    # Namespace env overrides as OPENAI_* (OPENAI_API_KEY, OPENAI_DEFAULT_MODEL,
    # OPENAI_TEMPERATURE, ...). Without a prefix pydantic-settings matches bare
    # field names, so a stray API_KEY or DEFAULT_MODEL variable would leak into
    # every provider's settings.
    model_config = SettingsConfigDict(env_prefix="OPENAI_")

    # Read from OPENAI_API_KEY; None when unset (the old `str` annotation hid that).
    api_key: Optional[str] = None
    default_model: str = "gpt-4o"
class AnthropicSettings(LLMProviderSettings):
    """Credentials and default model for the ``"anthropic"`` provider."""

    # Namespace env overrides as ANTHROPIC_* (ANTHROPIC_API_KEY, ...), so bare
    # API_KEY / DEFAULT_MODEL variables cannot leak in from the environment.
    model_config = SettingsConfigDict(env_prefix="ANTHROPIC_")

    # Read from ANTHROPIC_API_KEY; None when unset.
    api_key: Optional[str] = None
    default_model: str = "claude-3-5-sonnet-20240620"
    # Replaces the base class's None default with a concrete cap
    # (presumably because Anthropic's API requires max_tokens - confirm).
    max_tokens: int = 1024
class LlamaSettings(LLMProviderSettings):
    """Endpoint and default model for the ``"llama"`` provider.

    LLMFactory talks to this backend through the OpenAI client pointed at
    ``base_url``.
    """

    # Namespace env overrides as LLAMA_* (LLAMA_BASE_URL, LLAMA_DEFAULT_MODEL,
    # ...), so generic API_KEY / DEFAULT_MODEL variables do not apply here.
    model_config = SettingsConfigDict(env_prefix="LLAMA_")

    api_key: str = "key"  # required by the OpenAI client, but not used
    default_model: str = "llama3"
    base_url: str = "http://localhost:11434/v1"
class Settings(BaseSettings):
    """Application settings with one sub-config per LLM provider.

    The attribute names (``openai``, ``anthropic``, ``llama``) are the provider
    names LLMFactory looks up with ``getattr``.
    """

    app_name: str = "GenAI Project Template"
    # default_factory builds each provider config when Settings() is
    # constructed (the first get_settings() call, or after cache_clear()),
    # rather than once at import time when this class body runs. That lets
    # environment changes be picked up on rebuild.
    openai: OpenAISettings = Field(default_factory=OpenAISettings)
    anthropic: AnthropicSettings = Field(default_factory=AnthropicSettings)
    llama: LlamaSettings = Field(default_factory=LlamaSettings)
@lru_cache
def get_settings():
    """Return the shared ``Settings`` instance.

    Because this function takes no arguments, ``lru_cache`` builds
    ``Settings()`` on the first call and returns that same object afterwards.
    Use ``get_settings.cache_clear()`` to force a rebuild.
    """
    cached_settings = Settings()
    return cached_settings
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.