jimbrig, forked from daveebbelaar/llm_factory.py (August 17, 2024)
LLM Factory with Instructor
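Two files: llm_factory.py exposes a single LLMFactory class that returns an Instructor-patched client for OpenAI, Anthropic, or a local Llama model, and config/settings.py centralizes the per-provider defaults it reads.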
llm_factory.py

from typing import Any, Dict, List, Type

import instructor
from anthropic import Anthropic
from openai import OpenAI
from pydantic import BaseModel, Field

from config.settings import get_settings
class LLMFactory:
    """Factory that returns an Instructor-patched client for the chosen provider."""

    def __init__(self, provider: str):
        self.provider = provider
        # Pull the provider-specific settings block ("openai", "anthropic", or "llama").
        self.settings = getattr(get_settings(), provider)
        self.client = self._initialize_client()

    def _initialize_client(self) -> Any:
        client_initializers = {
            "openai": lambda s: instructor.from_openai(OpenAI(api_key=s.api_key)),
            "anthropic": lambda s: instructor.from_anthropic(
                Anthropic(api_key=s.api_key)
            ),
            # Local models behind an OpenAI-compatible API; JSON mode, since
            # local models generally lack native tool calling.
            "llama": lambda s: instructor.from_openai(
                OpenAI(base_url=s.base_url, api_key=s.api_key),
                mode=instructor.Mode.JSON,
            ),
        }
        initializer = client_initializers.get(self.provider)
        if initializer:
            return initializer(self.settings)
        raise ValueError(f"Unsupported LLM provider: {self.provider}")

    def create_completion(
        self, response_model: Type[BaseModel], messages: List[Dict[str, str]], **kwargs
    ) -> Any:
        # Per-call kwargs override the provider defaults from settings.
        completion_params = {
            "model": kwargs.get("model", self.settings.default_model),
            "temperature": kwargs.get("temperature", self.settings.temperature),
            "max_retries": kwargs.get("max_retries", self.settings.max_retries),
            "max_tokens": kwargs.get("max_tokens", self.settings.max_tokens),
            "response_model": response_model,
            "messages": messages,
        }
        # Instructor validates the output against response_model and retries on failure.
        return self.client.chat.completions.create(**completion_params)
if __name__ == "__main__":

    class CompletionModel(BaseModel):
        response: str = Field(description="Your response to the user.")
        reasoning: str = Field(description="Explain your reasoning for the response.")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "If it takes 2 hours to dry 1 shirt out in the sun, how long will it take to dry 5 shirts?",
        },
    ]

    llm = LLMFactory("openai")
    completion = llm.create_completion(
        response_model=CompletionModel,
        messages=messages,
    )
    assert isinstance(completion, CompletionModel)
    print(f"Response: {completion.response}\n")
    print(f"Reasoning: {completion.reasoning}")
config/settings.py

import os
from functools import lru_cache
from typing import Optional

from dotenv import load_dotenv
from pydantic_settings import BaseSettings

load_dotenv()


class LLMProviderSettings(BaseSettings):
    """Defaults shared by every provider."""

    temperature: float = 0.0
    max_tokens: Optional[int] = None
    max_retries: int = 3


class OpenAISettings(LLMProviderSettings):
    api_key: str = os.getenv("OPENAI_API_KEY")
    default_model: str = "gpt-4o"


class AnthropicSettings(LLMProviderSettings):
    api_key: str = os.getenv("ANTHROPIC_API_KEY")
    default_model: str = "claude-3-5-sonnet-20240620"
    max_tokens: int = 1024  # Anthropic requires max_tokens to be set


class LlamaSettings(LLMProviderSettings):
    api_key: str = "key"  # required by the OpenAI client, but not used by Ollama
    default_model: str = "llama3"
    base_url: str = "http://localhost:11434/v1"


class Settings(BaseSettings):
    app_name: str = "GenAI Project Template"
    openai: OpenAISettings = OpenAISettings()
    anthropic: AnthropicSettings = AnthropicSettings()
    llama: LlamaSettings = LlamaSettings()


@lru_cache
def get_settings():
    # Cached so Settings (and the .env parse behind it) is built once per process.
    return Settings()
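Adding a provider takes two edits: a settings class registered on Settings, and an entry in client_initializers. A hedged sketch for a hypothetical OpenAI-compatible provider such as Groq; the GroqSettings class, env var, model name, and base_url here are illustrative assumptions, not part of the original gist:

# Hypothetical extension, not in the original gist.
class GroqSettings(LLMProviderSettings):
    api_key: str = os.getenv("GROQ_API_KEY")
    default_model: str = "llama3-70b-8192"
    base_url: str = "https://api.groq.com/openai/v1"

# Then register it:
#   1) on Settings:
#        groq: GroqSettings = GroqSettings()
#   2) in LLMFactory._initialize_client:
#        "groq": lambda s: instructor.from_openai(
#            OpenAI(base_url=s.base_url, api_key=s.api_key),
#        ),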