Created January 11, 2025 15:05
Save tomshaw/8cf0b29eefe56abca5f4efd2be5e1204 to your computer and use it in GitHub Desktop.
Ollama structured output data generation example.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import logging

import pandas as pd
from ollama import chat
from pydantic import BaseModel
# Configure the root logger at import time so INFO-and-above records are shown.
logging.basicConfig(level="INFO")

# Result table with columns (country, capital, languages); populated by main().
data = pd.DataFrame(columns=["country", "capital", "languages"])
# Pydantic schema for a single country entry in the model's structured reply.
# NOTE: documented with comments rather than a docstring on purpose: pydantic
# copies a model's docstring into model_json_schema() as "description", and
# that schema is sent to the LLM, so a docstring would change the prompt.
# Field order is likewise part of the schema; keep it stable.
class Country(BaseModel):
    name: str  # written to the "country" column of the CSV
    capital: str  # written to the "capital" column
    languages: list[str]  # joined with "," into the "languages" column
# Top-level shape of the reply requested from the model: an object wrapping
# the list of countries. Its JSON schema is passed as chat(format=...) and the
# reply text is validated against it. Comments only: a docstring here would be
# embedded in that schema and alter what the model is asked to produce.
class CountryList(BaseModel):
    countries: list[Country]  # one entry per country, in the order returned
def main(model, output_path="countries.csv"):
    """Ask an Ollama model for every country and save the answer as CSV.

    The model is constrained to reply with JSON matching the ``CountryList``
    schema. The reply is validated with pydantic and turned into a table with
    columns (country, capital, languages). The table is written to
    ``output_path`` without a header row or index, and then printed.

    Args:
        model: Name of the Ollama model to query (e.g. ``"mistral"``).
        output_path: Destination CSV file. Defaults to ``"countries.csv"``.

    Any error from the API call or from schema validation is logged with its
    traceback. The function then returns without writing a file.
    """
    # Rebind (rather than append into) the module-level table so rows from a
    # previous call can never leak into this call's output.
    global data

    messages = [
        {
            "role": "user",
            "content": (
                "List names of all countries in an alphabetical order "
                "with capital and languages spoken."
            ),
        },
    ]
    try:
        raw = chat(
            model=model,
            messages=messages,
            format=CountryList.model_json_schema(),
        )
        result = CountryList.model_validate_json(raw.message.content)
    except Exception as exc:
        # Script-level boundary: keep the traceback instead of only str(exc).
        logging.exception("Error occurred while fetching data from API: %s", exc)
        return
    logging.info("Data received from API: %s", result)

    # Build the whole frame at once; row-by-row .loc enlargement re-copies
    # the frame on every insert.
    rows = [
        [country.name, country.capital, ",".join(country.languages)]
        for country in result.countries
    ]
    data = pd.DataFrame(rows, columns=["country", "capital", "languages"])
    data.to_csv(output_path, header=False, index=False)
    print(data)
if __name__ == "__main__":
    # Command-line entry point: choose the Ollama model, then run the pipeline.
    cli = argparse.ArgumentParser(description="Generate data using Ollama API")
    cli.add_argument(
        "--model",
        type=str,
        default="mistral",
        help="Model name to use for data generation (default: mistral)",
    )
    chosen = cli.parse_args()
    main(chosen.model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.