Skip to content

Instantly share code, notes, and snippets.

@danibene
Created August 2, 2023 17:34
Show Gist options
  • Save danibene/b2d5ec03a6c259246e590cbf1dda4e93 to your computer and use it in GitHub Desktop.
Save danibene/b2d5ec03a6c259246e590cbf1dda4e93 to your computer and use it in GitHub Desktop.
skl2onnx_pr_1004.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/danibene/b2d5ec03a6c259246e590cbf1dda4e93/skl2onnx_pr_1004.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"%pip install git+https://github.com/onnx/sklearn-onnx.git@3ef5e136ca2c02b7aae03d6879265a5691c661e2"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pIJOU3Jah4-A",
"outputId": "5f9424c0-035d-45ef-8515-9be5babc5fd0"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting git+https://github.com/onnx/sklearn-onnx.git\n",
" Cloning https://github.com/onnx/sklearn-onnx.git to /tmp/pip-req-build-ipyc08zc\n",
" Running command git clone --filter=blob:none --quiet https://github.com/onnx/sklearn-onnx.git /tmp/pip-req-build-ipyc08zc\n",
" Resolved https://github.com/onnx/sklearn-onnx.git to commit 3ef5e136ca2c02b7aae03d6879265a5691c661e2\n",
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: onnx>=1.2.1 in /usr/local/lib/python3.10/dist-packages (from skl2onnx==1.14.1) (1.14.0)\n",
"Requirement already satisfied: scikit-learn>=0.19 in /usr/local/lib/python3.10/dist-packages (from skl2onnx==1.14.1) (1.2.2)\n",
"Requirement already satisfied: onnxconverter-common>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from skl2onnx==1.14.1) (1.13.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx==1.14.1) (1.22.4)\n",
"Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx==1.14.1) (3.20.3)\n",
"Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx==1.14.1) (4.7.1)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxconverter-common>=1.7.0->skl2onnx==1.14.1) (23.1)\n",
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.19->skl2onnx==1.14.1) (1.10.1)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.19->skl2onnx==1.14.1) (1.3.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.19->skl2onnx==1.14.1) (3.2.0)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"import onnx\n",
"from skl2onnx import convert_sklearn\n",
"from skl2onnx.common.data_types import FloatTensorType\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.ensemble import RandomForestClassifier"
],
"metadata": {
"id": "yNHVo62uh0Fj"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Passed as random_state so every export trains the classifier with the same seed.\n",
"RANDOM_SEED = 42\n",
"# ONNX opset handed to convert_sklearn.\n",
"TARGET_OPSET = 12\n",
"# Directory that receives the exported .onnx files (created on demand).\n",
"EXPORT_PARENT_PATH = Path(\"exported_to_onnx\")\n",
"\n",
"\n",
"def export_sample_onnx(iteration=1):\n",
"    \"\"\"Train a random forest on the iris dataset and export it to ONNX.\n",
"\n",
"    The classifier is seeded with RANDOM_SEED and converted with opset\n",
"    TARGET_OPSET. The serialized model is written to\n",
"    EXPORT_PARENT_PATH / \"model_{iteration}.onnx\", creating the directory\n",
"    if needed and overwriting any existing file of that name.\n",
"\n",
"    Args:\n",
"        iteration: Suffix used in the output file name.\n",
"    \"\"\"\n",
"    iris = load_iris()\n",
"    X = iris.data\n",
"    y = iris.target\n",
"    clf = RandomForestClassifier(random_state=RANDOM_SEED)\n",
"    clf.fit(X, y)\n",
"    # Declare the input as (batch, n_features): the batch dimension is left\n",
"    # dynamic and the width is the number of feature columns, not the number\n",
"    # of training samples.\n",
"    initial_type = [(\"float_input\", FloatTensorType([None, X.shape[1]]))]\n",
"    converted = convert_sklearn(clf, initial_types=initial_type, target_opset=TARGET_OPSET)\n",
"    export_path = EXPORT_PARENT_PATH / f\"model_{iteration}.onnx\"\n",
"    export_path.parent.mkdir(parents=True, exist_ok=True)\n",
"    export_path.write_bytes(converted.SerializeToString())\n",
"\n",
"\n",
"def compare_onnx_graphs(model_1_onnx_path: Path, model_2_onnx_path: Path):\n",
"    \"\"\"Collect the graph nodes that differ between two ONNX model files.\n",
"\n",
"    Nodes are compared position by position with _check_same_onnx_nodes.\n",
"\n",
"    Args:\n",
"        model_1_onnx_path: Path of the first serialized ONNX model.\n",
"        model_2_onnx_path: Path of the second serialized ONNX model.\n",
"\n",
"    Returns:\n",
"        A list of dicts with keys \"graph_1\" and \"graph_2\", one per differing\n",
"        position. Empty when all nodes match. When the graphs have a\n",
"        different number of nodes, a message is printed and every node of\n",
"        the first graph is returned paired with the string \"not_same_len\".\n",
"    \"\"\"\n",
"    first_model = onnx.load(str(model_1_onnx_path))\n",
"    second_model = onnx.load(str(model_2_onnx_path))\n",
"    first_nodes = first_model.graph.node\n",
"    second_nodes = second_model.graph.node\n",
"\n",
"    # Graphs of different size are not compared node by node.\n",
"    if len(first_nodes) != len(second_nodes):\n",
"        print(\"The number of nodes in the graphs is different.\")\n",
"        return [{\"graph_1\": node, \"graph_2\": \"not_same_len\"} for node in first_nodes]\n",
"\n",
"    # Keep only the positions where the two nodes disagree.\n",
"    return [\n",
"        {\"graph_1\": left, \"graph_2\": right}\n",
"        for left, right in zip(first_nodes, second_nodes)\n",
"        if not _check_same_onnx_nodes(left, right)\n",
"    ]\n",
"\n",
"\n",
"def _check_same_onnx_nodes(node_1: onnx.NodeProto, node_2: onnx.NodeProto):\n",
"    \"\"\"Tell whether two ONNX nodes agree on the fields compared here.\n",
"\n",
"    The fields checked, in order, are attribute, input, output and op_type.\n",
"    Other fields (such as the node name or domain) are not compared.\n",
"\n",
"    Returns:\n",
"        True when none of the compared fields differ, False otherwise.\n",
"    \"\"\"\n",
"    compared_fields = (\"attribute\", \"input\", \"output\", \"op_type\")\n",
"    # any() stops at the first differing field, like the original elif chain.\n",
"    return not any(\n",
"        getattr(node_1, field) != getattr(node_2, field) for field in compared_fields\n",
"    )"
],
"metadata": {
"id": "JVRhoZegiGhv"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "umNNhhKHhqTa"
},
"outputs": [],
"source": [
"# Export the same model five times; the files are compared in the next cell.\n",
"for iteration in range(5):\n",
"    export_sample_onnx(iteration=iteration)"
]
},
{
"cell_type": "code",
"source": [
"# rglob yields files in a filesystem-dependent order; sort so the pairs of\n",
"# consecutive models being compared are the same on every run.\n",
"onnx_file_paths = sorted(EXPORT_PARENT_PATH.rglob(\"*.onnx\"))\n",
"\n",
"# Compare each exported model with the next one and print any differing nodes.\n",
"for current_path, next_path in zip(onnx_file_paths, onnx_file_paths[1:]):\n",
"    diffs_graphs = compare_onnx_graphs(current_path, next_path)\n",
"    if diffs_graphs:\n",
"        print(diffs_graphs)"
],
"metadata": {
"id": "FKSCk1zWiBSF",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "bd6a8fa4-45c7-4199-ab51-0f41f5b19294"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[{'graph_1': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
", 'graph_2': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
"}, {'graph_1': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
", 'graph_2': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
"}]\n",
"[{'graph_1': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
", 'graph_2': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
"}, {'graph_1': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
", 'graph_2': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
"}]\n",
"[{'graph_1': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
", 'graph_2': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
"}, {'graph_1': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
", 'graph_2': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
"}]\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "hTwMnDLXkp7o"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment