Last active
June 23, 2023 14:26
-
-
Save danibene/74a8986a20acae78d883c893878b6366 to your computer and use it in GitHub Desktop.
skl2onnx_issue_1001.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"name": "skl2onnx_issue_1001.ipynb", | |
"authorship_tag": "ABX9TyPLmmHpcFlnxBi5ZHYuAj+3", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/danibene/74a8986a20acae78d883c893878b6366/skl2onnx_issue_1000.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install skl2onnx" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "pIJOU3Jah4-A", | |
"outputId": "0d569801-65dd-4782-d264-18f73f7a5430" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting skl2onnx\n", | |
" Downloading skl2onnx-1.14.1-py2.py3-none-any.whl (292 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m292.3/292.3 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting onnx>=1.2.1 (from skl2onnx)\n", | |
" Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m74.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: scikit-learn<1.3,>=0.19 in /usr/local/lib/python3.10/dist-packages (from skl2onnx) (1.2.2)\n", | |
"Collecting onnxconverter-common>=1.7.0 (from skl2onnx)\n", | |
" Downloading onnxconverter_common-1.13.0-py2.py3-none-any.whl (83 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.8/83.8 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx) (1.22.4)\n", | |
"Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx) (3.20.3)\n", | |
"Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx) (4.5.0)\n", | |
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxconverter-common>=1.7.0->skl2onnx) (23.1)\n", | |
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<1.3,>=0.19->skl2onnx) (1.10.1)\n", | |
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<1.3,>=0.19->skl2onnx) (1.2.0)\n", | |
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<1.3,>=0.19->skl2onnx) (3.1.0)\n", | |
"Installing collected packages: onnx, onnxconverter-common, skl2onnx\n", | |
"Successfully installed onnx-1.14.0 onnxconverter-common-1.13.0 skl2onnx-1.14.1\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from pathlib import Path\n", | |
"\n", | |
"import numpy as np\n", | |
"import onnx\n", | |
"from skl2onnx import convert_sklearn\n", | |
"from skl2onnx.common.data_types import FloatTensorType\n", | |
"from sklearn.datasets import load_iris\n", | |
"from sklearn.ensemble import RandomForestClassifier" | |
], | |
"metadata": { | |
"id": "yNHVo62uh0Fj", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "c571b0ce-f961-4b5b-dadb-dbafc0f2450e" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.10/dist-packages/skl2onnx/algebra/onnx_ops.py:159: UserWarning: OpSchema.FormalParameter.typeStr is deprecated and will be removed in 1.16. Use OpSchema.FormalParameter.type_str instead.\n", | |
" tys = obj.typeStr or ''\n", | |
"/usr/local/lib/python3.10/dist-packages/skl2onnx/algebra/automation.py:154: UserWarning: OpSchema.FormalParameter.isHomogeneous is deprecated and will be removed in 1.16. Use OpSchema.FormalParameter.is_homogeneous instead.\n", | |
" if getattr(obj, 'isHomogeneous', False):\n", | |
"/usr/local/lib/python3.10/dist-packages/jinja2/environment.py:485: UserWarning: OpSchema.FormalParameter.typeStr is deprecated and will be removed in 1.16. Use OpSchema.FormalParameter.type_str instead.\n", | |
" return getattr(obj, attribute)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"RANDOM_SEED = 42\n", | |
"TARGET_OPSET = 12\n", | |
"EXPORT_PARENT_PATH = Path(\"exported_to_onnx\")\n", | |
"\n", | |
"\n", | |
"def export_sample_onnx(iteration=1):\n", | |
" # Load the dataset\n", | |
" iris = load_iris()\n", | |
" X = iris.data\n", | |
" y = iris.target\n", | |
" clf = RandomForestClassifier(random_state=RANDOM_SEED)\n", | |
" clf.fit(X, y)\n", | |
" initial_type = [(\"float_input\", FloatTensorType([1, len(X)]))]\n", | |
" converted = convert_sklearn(clf, initial_types=initial_type, target_opset=TARGET_OPSET)\n", | |
" export_path = Path(EXPORT_PARENT_PATH, \"model_\" + str(iteration) + \".onnx\")\n", | |
" Path(export_path).parent.mkdir(parents=True, exist_ok=True)\n", | |
" with open(export_path, \"wb\") as f:\n", | |
" f.write(converted.SerializeToString())\n", | |
"\n", | |
"\n", | |
"def compare_onnx_graphs(model_1_onnx_path: Path, model_2_onnx_path: Path):\n", | |
" # Load the ONNX models\n", | |
" model_1 = onnx.load(str(model_1_onnx_path))\n", | |
" model_2 = onnx.load(str(model_2_onnx_path))\n", | |
"\n", | |
" # Get the graphs from the models\n", | |
" graph_1 = model_1.graph\n", | |
" graph_2 = model_2.graph\n", | |
"\n", | |
" diff_graphs = []\n", | |
"\n", | |
" # Compare the number of nodes\n", | |
" if len(graph_1.node) != len(graph_2.node):\n", | |
" print(\"The number of nodes in the graphs is different.\")\n", | |
" return [{\"graph_1\": node, \"graph_2\": \"not_same_len\"} for node in graph_1.node]\n", | |
"\n", | |
" # Compare each node in the graphs\n", | |
" for node_1, node_2 in zip(graph_1.node, graph_2.node):\n", | |
" if not _check_same_onnx_nodes(node_1, node_2):\n", | |
" diff_graphs.append({\"graph_1\": node_1, \"graph_2\": node_2})\n", | |
"\n", | |
" return diff_graphs\n", | |
"\n", | |
"\n", | |
"def _check_same_onnx_nodes(node_1: onnx.NodeProto, node_2: onnx.NodeProto):\n", | |
" if node_1.attribute != node_2.attribute:\n", | |
" return False\n", | |
" elif node_1.input != node_2.input:\n", | |
" return False\n", | |
" elif node_1.output != node_2.output:\n", | |
" return False\n", | |
" elif node_1.op_type != node_2.op_type:\n", | |
" return False\n", | |
" else:\n", | |
" return True" | |
], | |
"metadata": { | |
"id": "JVRhoZegiGhv" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"id": "umNNhhKHhqTa" | |
}, | |
"outputs": [], | |
"source": [ | |
"for i in range(5):\n", | |
" export_sample_onnx(i)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"onnx_file_paths = list(EXPORT_PARENT_PATH.rglob(\"*.onnx\"))\n", | |
"\n", | |
"for i in range(len(onnx_file_paths) - 1):\n", | |
" diffs_graphs = compare_onnx_graphs(onnx_file_paths[i], onnx_file_paths[i + 1])\n", | |
" if len(diffs_graphs) > 0:\n", | |
" print(diffs_graphs)" | |
], | |
"metadata": { | |
"id": "FKSCk1zWiBSF", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "a3694b20-7020-47c1-ae40-a7e2164a1105" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"[{'graph_1': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
", 'graph_2': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
"}, {'graph_1': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
", 'graph_2': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
"}]\n", | |
"[{'graph_1': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
", 'graph_2': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
"}, {'graph_1': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
", 'graph_2': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
"}]\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "hTwMnDLXkp7o" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment