Skip to content

Instantly share code, notes, and snippets.

@tomshaw
Created January 11, 2025 15:05
Show Gist options
  • Save tomshaw/8cf0b29eefe56abca5f4efd2be5e1204 to your computer and use it in GitHub Desktop.
Save tomshaw/8cf0b29eefe56abca5f4efd2be5e1204 to your computer and use it in GitHub Desktop.
Ollama structured output data generation example.
from ollama import chat
from pydantic import BaseModel
import pandas as pd
import argparse
import logging
# Route INFO-and-above records through the root logger.
logging.basicConfig(level=logging.INFO)

# Country table, empty at import time; main() stores one row per country
# here, with the spoken languages collapsed into a single column.
data = pd.DataFrame(
    columns=[
        "country",
        "capital",
        "languages",
    ],
)
class Country(BaseModel):
    # One country entry of the structured reply requested from the model.
    #
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model's docstring into the generated JSON schema as its
    # "description", which would alter the schema passed to the chat call.
    name: str  # country name
    capital: str  # capital city
    languages: list[str]  # languages spoken, one list item per language
class CountryList(BaseModel):
    # Top-level shape of the reply: its JSON schema is what the chat call is
    # given as `format`, and the reply text is validated against it.
    #
    # Comments instead of a docstring on purpose: pydantic would copy a
    # docstring into the JSON schema as its "description".
    countries: list[Country]  # every country the model returned, in reply order
def main(model):
    """Ask an Ollama model for a country list and save it to ``countries.csv``.

    Sends a single chat request whose ``format`` is the JSON schema of
    ``CountryList``, validates the reply text against that model, rebuilds
    the module-level ``data`` frame from the validated countries, writes it
    to ``countries.csv`` (no header row, no index column) and prints it.

    Args:
        model: Name of the Ollama model to query; passed straight to ``chat``.

    Returns:
        None. If the request or the validation raises, the error is logged
        with its traceback and the function returns without writing the CSV
        or printing anything.
    """
    # The module-level frame is rebound to a freshly built one on every call,
    # so it can never carry rows left over from an earlier, longer reply.
    global data

    messages = [
        {"role": "user", "content": "List names of all countries in an alphabetical order with capital and languages spoken."},
    ]

    try:
        reply = chat(
            model=model,
            messages=messages,
            format=CountryList.model_json_schema(),
        )
        parsed = CountryList.model_validate_json(reply.message.content)
    except Exception as e:
        # Script-level boundary: any transport or validation failure ends the
        # run; logging.exception keeps the traceback for diagnosis.
        logging.exception("Error occurred while fetching data from API: %s", e)
        return

    logging.info("Data received from API: %s", parsed)

    # Build all rows first and construct the frame once, instead of growing
    # it one .loc assignment at a time. Languages are flattened into a single
    # comma-separated cell.
    rows = [
        [country.name, country.capital, ",".join(country.languages)]
        for country in parsed.countries
    ]
    data = pd.DataFrame(rows, columns=data.columns)

    # header=False is the documented way to omit the header row.
    data.to_csv("countries.csv", header=False, index=False)
    print(data)
# Command-line entry point: read the optional --model flag and run main().
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Generate data using Ollama API")
    cli.add_argument(
        "--model",
        type=str,
        default="mistral",
        help="Model name to use for data generation (default: mistral)",
    )
    options = cli.parse_args()
    main(options.model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment