Created
          July 2, 2025 02:51 
        
    domain-classification-with-mistral.ipynb
  
        
  
    
    
  
  
    
  | { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyODSID537+99KNPE5/enAU9", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/timathom/fcb8f8cf7851659e0b8b08ac924fb8d2/domain-classification-with-mistral.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Library Dataset Domain Classification with Mistral Classifier Factory\n", | |
| "\n", | |
| "This notebook demonstrates how to:\n", | |
| "1. Load labeled training datasets for domain classification\n", | |
| "2. Use the Mistral Classifier Factory (with ministral-3b-latest) to produce a fine-tuned model\n", | |
| "3. Test the fine-tuned model for accuracy\n", | |
| "\n", | |
| "The goal is to create a model that can classify catalog records by domain of activity to aid in entity resolution and name disambiguation tasks." | |
| ], | |
| "metadata": { | |
| "id": "r1C8n6m-0c6R" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 1: Install Required Libraries\n", | |
| "\n", | |
| "We need:\n", | |
| "- `mistralai`: To access the `ministral-3b` model for classifier fine-tuning\n", | |
| "- `datasets`: Hugging Face library for handling datasets\n", | |
| "- `wandb`: To access the Weights & Biases Mistral integration" | |
| ], | |
| "metadata": { | |
| "id": "pox6cLQn1ypQ" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Install required packages\n", | |
| "!pip install mistralai pandas matplotlib seaborn wandb datasets==3.2.0" | |
| ], | |
| "metadata": { | |
| "id": "KZuzgP5aXN-Q" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import os\n", | |
| "from google.colab import userdata\n", | |
| "import requests\n", | |
| "import json\n", | |
| "import random\n", | |
| "import time\n", | |
| "import pandas as pd\n", | |
| "from mistralai import Mistral\n", | |
| "from datasets import load_dataset\n", | |
| "from huggingface_hub import hf_hub_download\n", | |
| "import wandb\n", | |
| "RANDOM_SEED = 42" | |
| ], | |
| "metadata": { | |
| "id": "I_xTwuybWA2M" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 2: Configure API Keys\n", | |
| "\n", | |
| "Using Colab's secure userdata to access API keys for:\n", | |
| "- **Mistral AI**: For fine-tuning\n", | |
| "- **Hugging Face**: To download datasets\n", | |
| "- **Weights & Biases**: For experiment tracking during training" | |
| ], | |
| "metadata": { | |
| "id": "ssfc9uc23Euc" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n", | |
| "os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n", | |
| "os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')\n", | |
| "os.environ['MISTRAL_API_KEY'] = userdata.get('MISTRAL_API_KEY')\n", | |
| "os.environ['MISTRAL_CLASSIFIER'] = userdata.get('MISTRAL_CLASSIFIER')" | |
| ], | |
| "metadata": { | |
| "id": "aI0wI4ai3crw" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 3: Load Source Datasets\n", | |
| "\n", | |
| "Load two datasets:\n", | |
| "\n", | |
| "1. Labeled dataset of person entities with associated bibliographic metadata\n", | |
| "2. Labeled dataset of domain classifications corresponding to each record in the person entities dataset" | |
| ], | |
| "metadata": { | |
| "id": "apHynNLI3mBY" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Load datasets\n", | |
| "training_data = pd.DataFrame(load_dataset(\"timathom/yale-library-entity-resolver-training-data\")[\"train\"])\n", | |
| "classifications = pd.DataFrame(load_dataset(\"timathom/yale-library-entity-resolver-classifications\")[\"train\"])" | |
| ], | |
| "metadata": { | |
| "id": "4miK3DljqJc9" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "odstnTk1JFqB", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "f4b2ef9c-f02f-46d1-d923-3447ade81021" | |
| }, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "📚 Loaded Labeled Datasets\n", | |
| "==================================================\n", | |
| "\n", | |
| "Training DataFrame has 2539 rows\n", | |
| "\n", | |
| "First row:\n", | |
| "identity 9.1\n", | |
| "composite Title: Archäologie und Photographie: fünfzig...\n", | |
| "marcKey 7001 $aSchubert, Franz.\n", | |
| "person Schubert, Franz\n", | |
| "roles Contributor\n", | |
| "title Archäologie und Photographie: fünfzig Beispi...\n", | |
| "attribution ausgewählt von Franz Schubert und Susanne Gru...\n", | |
| "provision Mainz: P. von Zabern, 1978\n", | |
| "subjects Photography in archaeology\n", | |
| "genres None\n", | |
| "relatedWork None\n", | |
| "recordId 53144\n", | |
| "personId 53144#Agent700-22\n", | |
| "Name: 0, dtype: object\n", | |
| "Classification DataFrame has 2539 rows\n", | |
| "\n", | |
| "First row:\n", | |
| "personId 53144#Agent700-22\n", | |
| "label [Documentary and Technical Arts, History, Heri...\n", | |
| "path [Arts, Culture, and Creative Expression > Docu...\n", | |
| "rationale This catalog entry describes Franz Schubert as...\n", | |
| "Name: 0, dtype: object\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "if training_data is not None and not training_data.empty \\\n", | |
| "and classifications is not None and not classifications.empty:\n", | |
| "    print(\"📚 Loaded Labeled Datasets\")\n", | |
| "    print(\"=\" * 50)\n", | |
| "\n", | |
| "    print(f\"\\nTraining DataFrame has {len(training_data)} rows\\n\")\n", | |
| "    first_row_training = training_data.iloc[0]\n", | |
| "    print(\"First row:\")\n", | |
| "    print(first_row_training)\n", | |
| "\n", | |
| "    print(f\"Classification DataFrame has {len(classifications)} rows\\n\")\n", | |
| "    first_row_class = classifications.iloc[0]\n", | |
| "    print(\"First row:\")\n", | |
| "    print(first_row_class)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 4: Connect to Mistral" | |
| ], | |
| "metadata": { | |
| "id": "R__Ja6-P5826" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Initialize Mistral client\n", | |
| "client = Mistral(api_key=os.environ['MISTRAL_API_KEY'])\n", | |
| "print(\"🤖 Mistral client initialized\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "-aXqOvMzXIV9", | |
| "outputId": "5dfc4625-ddc9-4ca2-93c5-c6c4cc14c834" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "🤖 Mistral client initialized\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 5: Prepare the Data\n", | |
| "\n", | |
| "Merge the two datasets and convert them to Mistral's expected format." | |
| ], | |
| "metadata": { | |
| "id": "q_klXTSz6Hvh" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Create entity lookup\n", | |
| "entity_lookup = {}\n", | |
| "for _, row in training_data.iterrows():\n", | |
| "    person_id = str(row['personId'])\n", | |
| "    entity_lookup[person_id] = row['composite']\n", | |
| "\n", | |
| "print(f\"Created entity lookup for {len(entity_lookup)} entities\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "cAigzXIrXxJi", | |
| "outputId": "5d446cae-d220-46c9-b646-5b573c1b6d2c" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Created entity lookup for 2539 entities\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Convert to Mistral format\n", | |
| "training_examples = []\n", | |
| "\n", | |
| "for idx, row in classifications.iterrows():\n", | |
| "    person_id = str(row.get('personId', idx))  # String form matches the entity_lookup keys\n", | |
| "\n", | |
| "    # Get composite text\n", | |
| "    composite_text = entity_lookup.get(person_id)\n", | |
| "    if not composite_text:\n", | |
| "        continue\n", | |
| "\n", | |
| "    # Extract labels and parent categories\n", | |
| "    labels_list = row.get('label', [])\n", | |
| "    paths_list = row.get('path', [])\n", | |
| "\n", | |
| "    if not labels_list:\n", | |
| "        continue\n", | |
| "\n", | |
| "    # Extract parent categories from paths\n", | |
| "    parent_categories = []\n", | |
| "    for path in paths_list:\n", | |
| "        if \" > \" in path:\n", | |
| "            parent_categories.append(path.split(\" > \")[0])\n", | |
| "\n", | |
| "    # Create training example in Mistral format\n", | |
| "    training_examples.append({\n", | |
| "        \"text\": composite_text,\n", | |
| "        \"labels\": {\n", | |
| "            \"domain\": labels_list,  # Multi-label list\n", | |
| "            \"parent_category\": parent_categories\n", | |
| "        }\n", | |
| "    })\n", | |
| "\n", | |
| "print(f\"Created {len(training_examples)} training examples\")\n", | |
| "\n", | |
| "# Show sample\n", | |
| "print(\"\\n📝 Sample training example:\")\n", | |
| "sample_ex = training_examples[0]\n", | |
| "print(f\"Text: {sample_ex['text'][:500]}\")\n", | |
| "print(f\"Domains: {sample_ex['labels']['domain']}\")\n", | |
| "print(f\"Parents: {sample_ex['labels']['parent_category']}\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "P1kKqj__X56t", | |
| "outputId": "a463ff2c-1059-4b3b-b447-a804b841e349" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Created 2539 training examples\n", | |
| "\n", | |
| "📝 Sample training example:\n", | |
| "Text: Title: Archäologie und Photographie: fünfzig Beispiele zur Geschichte und Methode\n", | |
| "Subjects: Photography in archaeology\n", | |
| "Provision information: Mainz: P. von Zabern, 1978\n", | |
| "Domains: ['Documentary and Technical Arts', 'History, Heritage, and Memory']\n", | |
| "Parents: ['Arts, Culture, and Creative Expression', 'Humanities, Thought, and Interpretation']\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 6: Create Local Data Splits\n", | |
| "\n", | |
| "Create local train/validation splits (80/20) from the merged data for the fine-tuning process." | |
| ], | |
| "metadata": { | |
| "id": "bhBcEqrU6_BT" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Split data (80% train, 20% validation)\n", | |
| "random.seed(RANDOM_SEED)\n", | |
| "random.shuffle(training_examples)\n", | |
| "\n", | |
| "split_idx = int(len(training_examples) * 0.8)\n", | |
| "train_examples = training_examples[:split_idx]\n", | |
| "val_examples = training_examples[split_idx:]\n", | |
| "\n", | |
| "print(f\"Training set: {len(train_examples)} examples\")\n", | |
| "print(f\"Validation set: {len(val_examples)} examples\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "ulEtxM25ZPer", | |
| "outputId": "3649ec1e-9f9e-492e-e0a3-8820ad5a9937" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Training set: 2031 examples\n", | |
| "Validation set: 508 examples\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 7: Save the Data as JSON-L to Upload to Mistral" | |
| ], | |
| "metadata": { | |
| "id": "iW9s9jKg7TO-" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Save to JSONL files\n", | |
| "def save_jsonl(examples, filepath):\n", | |
| "    with open(filepath, 'w', encoding='utf-8') as f:\n", | |
| "        for example in examples:\n", | |
| "            f.write(json.dumps(example, ensure_ascii=False) + '\\n')\n", | |
| "    print(f\"Saved {len(examples)} examples to {filepath}\")\n", | |
| "\n", | |
| "os.makedirs(\"mistral\", exist_ok=True)\n", | |
| "\n", | |
| "train_path = \"./mistral/mistral_train_2025-07-01.jsonl\"\n", | |
| "val_path = \"./mistral/mistral_val_2025-07-01.jsonl\"\n", | |
| "\n", | |
| "save_jsonl(train_examples, train_path)\n", | |
| "save_jsonl(val_examples, val_path)\n", | |
| "\n", | |
| "print(\"✅ Data preparation complete!\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "RDsLQYAIZppw", | |
| "outputId": "8d7b6f48-af9b-411d-ddd0-df6a3b95faa6" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Saved 2031 examples to ./mistral/mistral_train_2025-07-01.jsonl\n", | |
| "Saved 508 examples to ./mistral/mistral_val_2025-07-01.jsonl\n", | |
| "✅ Data preparation complete!\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 8: Upload the Data and Retrieve the File IDs" | |
| ], | |
| "metadata": { | |
| "id": "9filfR3U7duv" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Upload the training data\n", | |
| "print(\"📤 Uploading training data...\")\n", | |
| "training_data = client.files.upload(\n", | |
| "    file={\n", | |
| "        \"file_name\": \"mistral_train_2025-07-01.jsonl\",\n", | |
| "        \"content\": open(train_path, \"rb\"),\n", | |
| "    }\n", | |
| ")\n", | |
| "print(f\"✅ Training file uploaded: {training_data.id}\")\n", | |
| "\n", | |
| "# Upload the validation data\n", | |
| "print(\"📤 Uploading validation data...\")\n", | |
| "validation_data = client.files.upload(\n", | |
| "    file={\n", | |
| "        \"file_name\": \"mistral_val_2025-07-01.jsonl\",\n", | |
| "        \"content\": open(val_path, \"rb\"),\n", | |
| "    }\n", | |
| ")\n", | |
| "print(f\"✅ Validation file uploaded: {validation_data.id}\")\n", | |
| "\n", | |
| "print(\"\\n📋 File IDs:\")\n", | |
| "print(f\"Training: {training_data.id}\")\n", | |
| "print(f\"Validation: {validation_data.id}\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "355f9QY8arT5", | |
| "outputId": "aa326a57-c070-45d7-f26b-64557699aee6" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "📤 Uploading training data...\n", | |
| "✅ Training file uploaded: d55689f1-ba6b-4cc9-8011-3f1e833d5ef6\n", | |
| "📤 Uploading validation data...\n", | |
| "✅ Validation file uploaded: 5c164b8c-2ed2-4840-bd1f-fb9a151615d7\n", | |
| "\n", | |
| "📋 File IDs:\n", | |
| "Training: d55689f1-ba6b-4cc9-8011-3f1e833d5ef6\n", | |
| "Validation: 5c164b8c-2ed2-4840-bd1f-fb9a151615d7\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [], | |
| "metadata": { | |
| "id": "bTqsqPLL7pm1" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 9: Initialize a W&B Experiment for Tracking" | |
| ], | |
| "metadata": { | |
| "id": "BYMX7Czf8CfE" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Initialize Weights & Biases for experiment tracking\n", | |
| "def setup_wandb_experiment(project_name: str = \"entity_resolver\") -> bool:\n", | |
| "    \"\"\"Set up W&B experiment tracking.\"\"\"\n", | |
| "    try:\n", | |
| "        # os.environ is a mapping, not a callable, so use .get()\n", | |
| "        if os.environ.get('WANDB_API_KEY'):\n", | |
| "            wandb.login(key=os.environ.get('WANDB_API_KEY'))\n", | |
| "\n", | |
| "        wandb.init(\n", | |
| "            project=project_name,\n", | |
| "            name=\"mistral-entity-classifier-2025-07-02\",\n", | |
| "            config={\n", | |
| "                \"model\": \"ministral-3b-latest\",\n", | |
| "                \"training_steps\": 250,\n", | |
| "                \"learning_rate\": 0.00007,\n", | |
| "                \"dataset_size\": 2031,\n", | |
| "                \"multi_label\": True,\n", | |
| "                \"random_seed\": RANDOM_SEED\n", | |
| "            },\n", | |
| "            tags=[\"mistral\", \"entity-resolution\", \"multilabel\", \"taxonomy\"]\n", | |
| "        )\n", | |
| "\n", | |
| "        print(\"✅ Weights & Biases experiment initialized\")\n", | |
| "        return True\n", | |
| "\n", | |
| "    except Exception as e:\n", | |
| "        print(f\"⚠️ W&B setup failed: {e}\")\n", | |
| "        print(\" Continuing without W&B tracking...\")\n", | |
| "        return False\n", | |
| "\n", | |
| "# Setup W&B (optional)\n", | |
| "wandb_enabled = setup_wandb_experiment() if os.environ.get('WANDB_API_KEY') else False" | |
| ], | |
| "metadata": { | |
| "id": "E8HVGoaga9vv" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 10: Create the Fine-Tuning Job with the Uploaded Data Files\n", | |
| "\n", | |
| "We create the job with `auto_start=True` to begin immediately; `auto_start=False` would allow us to evaluate the estimated cost of the job first and then manually initiate it." | |
| ], | |
| "metadata": { | |
| "id": "8lkQmvLr8KSZ" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Create a fine-tuning job\n", | |
| "created_job = client.fine_tuning.jobs.create(\n", | |
| "    model=\"ministral-3b-latest\",\n", | |
| "    job_type=\"classifier\",\n", | |
| "    training_files=[{\"file_id\": training_data.id, \"weight\": 1}],\n", | |
| "    validation_files=[validation_data.id],\n", | |
| "    hyperparameters={\"training_steps\": 250, \"learning_rate\": 0.00007},\n", | |
| "    auto_start=True,\n", | |
| "    integrations=[\n", | |
| "        {\n", | |
| "            \"project\": \"entity_resolver\",\n", | |
| "            \"name\": \"mistral-entity-classifier-1751414690\",\n", | |
| "            \"api_key\": os.environ.get('WANDB_API_KEY'),  # os.environ is a mapping, not a callable\n", | |
| "        }\n", | |
| "    ]\n", | |
| ")\n", | |
| "print(json.dumps(created_job.model_dump(), indent=4))" | |
| ], | |
| "metadata": { | |
| "id": "6onuEKhtda1D" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 11: Monitor the Status of the Job\n", | |
| "\n", | |
| "Small jobs should complete quickly, but time to completion depends on resource availability." | |
| ], | |
| "metadata": { | |
| "id": "kMn3kYrM8n7e" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Retrieve the job details\n", | |
| "retrieved_job = client.fine_tuning.jobs.get(job_id=created_job.id)\n", | |
| "print(json.dumps(retrieved_job.model_dump(), indent=4))" | |
| ], | |
| "metadata": { | |
| "id": "Ve0NMF7leEmz" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Step 12: Test the Fine-Tuned Model\n", | |
| "\n", | |
| "We didn't create an explicit test split, so here we just try a few labeled synthetic examples." | |
| ], | |
| "metadata": { | |
| "id": "TIa-XXr08_O_" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Test data with ground truth labels\n", | |
| "test_data = [\n", | |
| "    {\n", | |
| "        \"text\": \"Title: Quartette für zwei Violinen, Viola, Violoncell\\nSubjects: String quartets--Scores\",\n", | |
| "        \"domain\": \"Music, Sound, and Sonic Arts\",\n", | |
| "        \"parent_category\": \"Arts, Culture, and Creative Expression\"\n", | |
| "    },\n", | |
| "    {\n", | |
| "        \"text\": \"Title: Strategic management : concepts and cases\\nSubjects: Strategic planning; Management; Business planning\",\n", | |
| "        \"domain\": \"Economics, Business, and Finance\",\n", | |
| "        \"parent_category\": \"Society, Governance, and Public Life\"\n", | |
| "    },\n", | |
| "    {\n", | |
| "        \"text\": \"Title: Organic chemistry : structure and function\\nSubjects: Chemistry, Organic; Organic compounds--Structure\",\n", | |
| "        \"domain\": \"Natural Sciences\",\n", | |
| "        \"parent_category\": \"Sciences, Research, and Discovery\"\n", | |
| "    },\n", | |
| "    {\n", | |
| "        \"text\": \"Title: John Wesley's Sunday service of the Methodists\\nSubjects: Methodist Church--Liturgy--Texts\",\n", | |
| "        \"domain\": \"Religion, Theology, and Spirituality\",\n", | |
| "        \"parent_category\": \"Humanities, Thought, and Interpretation\"\n", | |
| "    },\n", | |
| "    {\n", | |
| "        \"text\": \"Title: Archaeology and photography : the early years, 1868-1880\\nSubjects: Photography in archaeology\",\n", | |
| "        \"domain\": \"History, Heritage, and Memory\",\n", | |
| "        \"parent_category\": \"Humanities, Thought, and Interpretation\"\n", | |
| "    }\n", | |
| "]" | |
| ], | |
| "metadata": { | |
| "id": "e__fK2-airJP" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "The classifier was trained on multi-label, multi-target data. Uncomment the `json.dumps` line below to view the raw output, which displays a ranked score for each class." | |
| ], | |
| "metadata": { | |
| "id": "5zj1a28V9gY-" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def classify_text(text, model_id):\n", | |
| "    try:\n", | |
| "        response = client.classifiers.classify(model=model_id, inputs=[text])\n", | |
| "        data = response.model_dump()\n", | |
| "\n", | |
| "        #print(json.dumps(data, indent=4))\n", | |
| "\n", | |
| "        # Extract highest scoring predictions\n", | |
| "        domain_scores = data[\"results\"][0][\"domain\"][\"scores\"]\n", | |
| "        parent_scores = data[\"results\"][0][\"parent_category\"][\"scores\"]\n", | |
| "\n", | |
| "        pred_domain = max(domain_scores, key=domain_scores.get)\n", | |
| "        pred_parent = max(parent_scores, key=parent_scores.get)\n", | |
| "\n", | |
| "        return pred_domain, pred_parent\n", | |
| "    except Exception as e:\n", | |
| "        print(f\"Error: {e}\")\n", | |
| "        return None, None" | |
| ], | |
| "metadata": { | |
| "id": "791E30WJisYP" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Compare the classifier results to the ground-truth examples to evaluate accuracy." | |
| ], | |
| "metadata": { | |
| "id": "P0SCYkyh-AJ5" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def evaluate_classifier(test_data, model_id):\n", | |
| "    results = []\n", | |
| "\n", | |
| "    for i, item in enumerate(test_data, 1):\n", | |
| "        pred_domain, pred_parent = classify_text(item[\"text\"], model_id)\n", | |
| "\n", | |
| "        domain_pass = item[\"domain\"] == pred_domain\n", | |
| "        parent_pass = item[\"parent_category\"] == pred_parent\n", | |
| "\n", | |
| "        results.append({\n", | |
| "            'test_id': i,\n", | |
| "            'domain_result': 'PASS' if domain_pass else 'FAIL',\n", | |
| "            'parent_result': 'PASS' if parent_pass else 'FAIL',\n", | |
| "            'pred_domain': pred_domain,\n", | |
| "            'pred_parent': pred_parent\n", | |
| "        })\n", | |
| "\n", | |
| "        print(f\"Test {i}: Domain {results[-1]['domain_result']}, Parent {results[-1]['parent_result']}\")\n", | |
| "        if not domain_pass:\n", | |
| "            print(f\" Expected: {item['domain']}\")\n", | |
| "            print(f\" Got: {pred_domain}\")\n", | |
| "        if not parent_pass:\n", | |
| "            print(f\" Expected: {item['parent_category']}\")\n", | |
| "            print(f\" Got: {pred_parent}\")\n", | |
| "\n", | |
| "    return results" | |
| ], | |
| "metadata": { | |
| "id": "iqsibSpMlDWW" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Output the evaluation results." | |
| ], | |
| "metadata": { | |
| "id": "40ix8VG--5gp" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model_id = os.environ.get('MISTRAL_CLASSIFIER')\n", | |
| "\n", | |
| "if model_id:\n", | |
| "    results = evaluate_classifier(test_data, model_id)\n", | |
| "\n", | |
| "    domain_passes = sum(1 for r in results if r['domain_result'] == 'PASS')\n", | |
| "    parent_passes = sum(1 for r in results if r['parent_result'] == 'PASS')\n", | |
| "    total = len(results)\n", | |
| "\n", | |
| "    print(\"\\nFinal Results:\")\n", | |
| "    print(f\"Domain: {domain_passes}/{total} PASS ({domain_passes/total:.1%})\")\n", | |
| "    print(f\"Parent: {parent_passes}/{total} PASS ({parent_passes/total:.1%})\")\n", | |
| "\n", | |
| "else:\n", | |
| "    print(\"No model ID found\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "0rsQPlIblFV2", | |
| "outputId": "c2321646-4d4c-448e-ccba-edcd03f2ca3c" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Test 1: Domain PASS, Parent PASS\n", | |
| "Test 2: Domain PASS, Parent PASS\n", | |
| "Test 3: Domain PASS, Parent PASS\n", | |
| "Test 4: Domain PASS, Parent PASS\n", | |
| "Test 5: Domain FAIL, Parent FAIL\n", | |
| " Expected: History, Heritage, and Memory\n", | |
| " Got: Visual Arts and Design\n", | |
| " Expected: Humanities, Thought, and Interpretation\n", | |
| " Got: Arts, Culture, and Creative Expression\n", | |
| "\n", | |
| "Final Results:\n", | |
| "Domain: 4/5 PASS (80.0%)\n", | |
| "Parent: 4/5 PASS (80.0%)\n" | |
| ] | |
| } | |
| ] | |
| } | |
| ] | |
| } | 
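
Step 12 of the notebook keeps only the single highest-scoring class per target, although the training data is multi-label. A minimal sketch of one alternative: turn the ranked scores into label sets with a threshold, then score them with micro-averaged precision and recall. The helper names, the 0.5 threshold, and the score values below are assumptions made for illustration; they are not part of the notebook or of the Mistral SDK.

```python
from typing import Dict, List, Set, Tuple


def labels_above_threshold(scores: Dict[str, float], threshold: float = 0.5) -> Set[str]:
    """Return every label whose score meets the threshold.

    Falls back to the single top-scoring label when nothing clears the
    threshold, so each record always receives at least one prediction.
    """
    picked = {label for label, score in scores.items() if score >= threshold}
    if not picked and scores:
        picked = {max(scores, key=scores.get)}
    return picked


def micro_precision_recall(gold: List[Set[str]], pred: List[Set[str]]) -> Tuple[float, float]:
    """Micro-averaged precision and recall over parallel lists of label sets."""
    true_pos = sum(len(g & p) for g, p in zip(gold, pred))
    n_pred = sum(len(p) for p in pred)
    n_gold = sum(len(g) for g in gold)
    precision = true_pos / n_pred if n_pred else 0.0
    recall = true_pos / n_gold if n_gold else 0.0
    return precision, recall


# Made-up scores for illustration only; NOT output from the fine-tuned model.
example_scores = [
    {"Documentary and Technical Arts": 0.81, "History, Heritage, and Memory": 0.64, "Natural Sciences": 0.03},
    {"Music, Sound, and Sonic Arts": 0.40, "Visual Arts and Design": 0.35},
    {"Natural Sciences": 0.90, "Economics, Business, and Finance": 0.70},
]
gold_labels = [
    {"Documentary and Technical Arts", "History, Heritage, and Memory"},
    {"Music, Sound, and Sonic Arts"},
    {"Natural Sciences"},
]

predicted = [labels_above_threshold(s) for s in example_scores]
precision, recall = micro_precision_recall(gold_labels, predicted)
print(f"precision={precision:.2f} recall={recall:.2f}")
```

With real classifier output, the per-target `scores` dictionaries that `classify_text` already reads could be passed to `labels_above_threshold`, and the held-out validation split could supply the gold label sets. The threshold would need tuning on that data.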
  