Skip to content

Instantly share code, notes, and snippets.

@fsndzomga
Last active December 13, 2024 17:13
Show Gist options
  • Save fsndzomga/f48b4456e391d361e856d5562d4d96a9 to your computer and use it in GitHub Desktop.
Save fsndzomga/f48b4456e391d361e856d5562d4d96a9 to your computer and use it in GitHub Desktop.
import time
import json
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from pydantic import BaseModel, ValidationError, ConfigDict
from typing import Type, Optional
import os
from dotenv import load_dotenv
load_dotenv()
class MistralLanguageModel:
    """
    Thin wrapper around ``MistralClient`` for single-prompt chat completions,
    optionally constrained to a JSON shape described by a pydantic model.
    """

    # Upper bound (seconds) for the exponential backoff between attempts, so
    # a long outage does not escalate into hour-long sleeps.
    _MAX_RETRY_DELAY = 60.0

    def __init__(self, api_key=None, model="mistral-tiny", temperature=0.5):
        """
        Create the client.

        Args:
            api_key: Mistral API key. When omitted (or empty), the
                ``MISTRAL_API_KEY`` environment variable is read at call
                time rather than once at import time.
            model: Name of the Mistral model to query.
            temperature: Sampling temperature sent with every request.

        Raises:
            ValueError: If no key is passed and the environment variable is
                not set.
        """
        if not api_key:
            # Looked up lazily: a default of os.environ[...] in the signature
            # would raise KeyError while the class is being defined.
            api_key = os.environ.get("MISTRAL_API_KEY")
        if not api_key:
            raise ValueError(
                "The Mistral API KEY must be provided either as "
                "an argument or as an environment variable named 'MISTRAL_API_KEY'"  # noqa
            )
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.client = MistralClient(api_key=self.api_key)

    def generate(self, prompt: str,
                 output_format: Optional[Type[BaseModel]] = None,
                 max_tokens: Optional[int] = None,
                 max_retries: Optional[int] = None):
        """
        Send ``prompt`` to the model and return the reply text.

        The request is retried with exponential backoff whenever the API call
        raises or, when ``output_format`` is given, whenever the reply is not
        JSON that validates against that model.

        Args:
            prompt: User message.
            output_format: Optional pydantic model describing the JSON keys
                the reply must contain.
            max_tokens: Optional cap on generated tokens.
            max_retries: Maximum number of retries after the first attempt.
                ``None`` (the default) retries indefinitely.

        Returns:
            The raw reply string (JSON text when ``output_format`` is given).

        Raises:
            RuntimeError: If ``max_retries`` is set and every attempt failed.
        """
        # Build the request once, outside the retry loop: it does not change
        # between attempts, and errors raised here are programming errors
        # that must surface instead of being retried forever.
        system_message = "You are a helpful assistant."
        if output_format:
            system_message += (
                " Respond in a JSON format that contains the following keys: "
                f"{self._model_structure_repr(output_format)}"
            )
        messages = [
            ChatMessage(role="system", content=system_message),
            ChatMessage(role="user", content=prompt),
        ]
        params = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
        }
        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        retry_delay = 0.1
        attempt = 0
        while True:
            attempt += 1
            failure = None
            try:
                response = self.client.chat(**params)
                response_content = response.choices[0].message.content
            except Exception as exc:
                # Any API/transport failure is retried, but report what
                # actually went wrong instead of assuming a rate limit.
                failure = exc
                reason = f"Request failed ({exc!r})"
            else:
                if not output_format or self._is_valid_json_for_model(
                        response_content, output_format):
                    return response_content
                reason = "Response did not match the requested output format"

            if max_retries is not None and attempt > max_retries:
                raise RuntimeError(
                    f"Giving up after {attempt} attempts: {reason}"
                ) from failure

            print(f"{reason}. Retrying in {retry_delay} seconds.")
            # Back off after validation failures too, so a model that keeps
            # returning malformed output does not hammer the API.
            time.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, self._MAX_RETRY_DELAY)

    def _model_structure_repr(self, model: Type[BaseModel]) -> str:
        """
        Return a ``"name: type, ..."`` description of the model's fields.

        Uses ``model_fields`` so that fields inherited from parent models are
        included, which ``__annotations__`` alone would omit.
        """
        return ', '.join(
            f'{name}: {field.annotation}'
            for name, field in model.model_fields.items()
        )

    def _is_valid_json_for_model(self, text: str, model: Type[BaseModel]) -> bool:  # noqa
        """
        Check if a text is valid JSON and if it respects the provided BaseModel.  # noqa

        Validation is strict and is requested per call, so the caller's model
        class and its ``model_config`` are left untouched.
        """
        if not isinstance(text, (str, bytes, bytearray)):
            # e.g. a reply whose content is None
            return False
        try:
            # Malformed JSON, non-object payloads and wrong field types are
            # all reported by pydantic as ValidationError.
            model.model_validate_json(text, strict=True)
        except ValidationError:
            return False
        return True
class Output(BaseModel):
    # Schema used by the demo below as the `output_format` of a request.
    # Field order is kept as declared: the field listing is turned into the
    # key description that is appended to the system prompt.
    first_name: str
    last_name: str
    city: str
# Demo: ask the model to extract the `Output` fields from one sentence and
# print the raw JSON reply. Running this issues a live request through the
# Mistral client.
llm = MistralLanguageModel()
prompt = (
    'Extract the requested information from the following sentence: '
    '"Alice Johnson is visiting Rome."'
)
response = llm.generate(prompt=prompt, output_format=Output)
print(response)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment