@danibene
Last active June 23, 2023 14:26
skl2onnx_issue_1001.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"name": "skl2onnx_issue_1001.ipynb",
"authorship_tag": "ABX9TyPLmmHpcFlnxBi5ZHYuAj+3",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/danibene/74a8986a20acae78d883c893878b6366/skl2onnx_issue_1000.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install skl2onnx"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pIJOU3Jah4-A",
"outputId": "0d569801-65dd-4782-d264-18f73f7a5430"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting skl2onnx\n",
" Downloading skl2onnx-1.14.1-py2.py3-none-any.whl (292 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m292.3/292.3 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting onnx>=1.2.1 (from skl2onnx)\n",
" Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m74.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: scikit-learn<1.3,>=0.19 in /usr/local/lib/python3.10/dist-packages (from skl2onnx) (1.2.2)\n",
"Collecting onnxconverter-common>=1.7.0 (from skl2onnx)\n",
" Downloading onnxconverter_common-1.13.0-py2.py3-none-any.whl (83 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.8/83.8 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx) (1.22.4)\n",
"Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx) (3.20.3)\n",
"Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx) (4.5.0)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxconverter-common>=1.7.0->skl2onnx) (23.1)\n",
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<1.3,>=0.19->skl2onnx) (1.10.1)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<1.3,>=0.19->skl2onnx) (1.2.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<1.3,>=0.19->skl2onnx) (3.1.0)\n",
"Installing collected packages: onnx, onnxconverter-common, skl2onnx\n",
"Successfully installed onnx-1.14.0 onnxconverter-common-1.13.0 skl2onnx-1.14.1\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"import onnx\n",
"from skl2onnx import convert_sklearn\n",
"from skl2onnx.common.data_types import FloatTensorType\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.ensemble import RandomForestClassifier"
],
"metadata": {
"id": "yNHVo62uh0Fj",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c571b0ce-f961-4b5b-dadb-dbafc0f2450e"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/skl2onnx/algebra/onnx_ops.py:159: UserWarning: OpSchema.FormalParameter.typeStr is deprecated and will be removed in 1.16. Use OpSchema.FormalParameter.type_str instead.\n",
" tys = obj.typeStr or ''\n",
"/usr/local/lib/python3.10/dist-packages/skl2onnx/algebra/automation.py:154: UserWarning: OpSchema.FormalParameter.isHomogeneous is deprecated and will be removed in 1.16. Use OpSchema.FormalParameter.is_homogeneous instead.\n",
" if getattr(obj, 'isHomogeneous', False):\n",
"/usr/local/lib/python3.10/dist-packages/jinja2/environment.py:485: UserWarning: OpSchema.FormalParameter.typeStr is deprecated and will be removed in 1.16. Use OpSchema.FormalParameter.type_str instead.\n",
" return getattr(obj, attribute)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"RANDOM_SEED = 42\n",
"TARGET_OPSET = 12\n",
"EXPORT_PARENT_PATH = Path(\"exported_to_onnx\")\n",
"\n",
"\n",
"def export_sample_onnx(iteration=1):\n",
" # Load the dataset\n",
" iris = load_iris()\n",
" X = iris.data\n",
" y = iris.target\n",
" clf = RandomForestClassifier(random_state=RANDOM_SEED)\n",
" clf.fit(X, y)\n",
" initial_type = [(\"float_input\", FloatTensorType([1, len(X)]))]\n",
" converted = convert_sklearn(clf, initial_types=initial_type, target_opset=TARGET_OPSET)\n",
" export_path = Path(EXPORT_PARENT_PATH, \"model_\" + str(iteration) + \".onnx\")\n",
" Path(export_path).parent.mkdir(parents=True, exist_ok=True)\n",
" with open(export_path, \"wb\") as f:\n",
" f.write(converted.SerializeToString())\n",
"\n",
"\n",
"def compare_onnx_graphs(model_1_onnx_path: Path, model_2_onnx_path: Path):\n",
" # Load the ONNX models\n",
" model_1 = onnx.load(str(model_1_onnx_path))\n",
" model_2 = onnx.load(str(model_2_onnx_path))\n",
"\n",
" # Get the graphs from the models\n",
" graph_1 = model_1.graph\n",
" graph_2 = model_2.graph\n",
"\n",
" diff_graphs = []\n",
"\n",
" # Compare the number of nodes\n",
" if len(graph_1.node) != len(graph_2.node):\n",
" print(\"The number of nodes in the graphs is different.\")\n",
" return [{\"graph_1\": node, \"graph_2\": \"not_same_len\"} for node in graph_1.node]\n",
"\n",
" # Compare each node in the graphs\n",
" for node_1, node_2 in zip(graph_1.node, graph_2.node):\n",
" if not _check_same_onnx_nodes(node_1, node_2):\n",
" diff_graphs.append({\"graph_1\": node_1, \"graph_2\": node_2})\n",
"\n",
" return diff_graphs\n",
"\n",
"\n",
"def _check_same_onnx_nodes(node_1: onnx.NodeProto, node_2: onnx.NodeProto):\n",
" if node_1.attribute != node_2.attribute:\n",
" return False\n",
" elif node_1.input != node_2.input:\n",
" return False\n",
" elif node_1.output != node_2.output:\n",
" return False\n",
" elif node_1.op_type != node_2.op_type:\n",
" return False\n",
" else:\n",
" return True"
],
"metadata": {
"id": "JVRhoZegiGhv"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "umNNhhKHhqTa"
},
"outputs": [],
"source": [
"for i in range(5):\n",
" export_sample_onnx(i)"
]
},
{
"cell_type": "code",
"source": [
"onnx_file_paths = list(EXPORT_PARENT_PATH.rglob(\"*.onnx\"))\n",
"\n",
"for i in range(len(onnx_file_paths) - 1):\n",
" diffs_graphs = compare_onnx_graphs(onnx_file_paths[i], onnx_file_paths[i + 1])\n",
" if len(diffs_graphs) > 0:\n",
" print(diffs_graphs)"
],
"metadata": {
"id": "FKSCk1zWiBSF",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a3694b20-7020-47c1-ae40-a7e2164a1105"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[{'graph_1': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
", 'graph_2': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
"}, {'graph_1': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
", 'graph_2': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
"}]\n",
"[{'graph_1': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
", 'graph_2': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
"}, {'graph_1': input: \"label\"\n",
"output: \"output_label\"\n",
"name: \"Cast\"\n",
"op_type: \"Cast\"\n",
"attribute {\n",
" name: \"to\"\n",
" i: 7\n",
" type: INT\n",
"}\n",
"domain: \"\"\n",
", 'graph_2': input: \"probabilities\"\n",
"output: \"output_probability\"\n",
"name: \"ZipMap\"\n",
"op_type: \"ZipMap\"\n",
"attribute {\n",
" name: \"classlabels_int64s\"\n",
" ints: 0\n",
" ints: 1\n",
" ints: 2\n",
" type: INTS\n",
"}\n",
"domain: \"ai.onnx.ml\"\n",
"}]\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "hTwMnDLXkp7o"
},
"execution_count": 5,
"outputs": []
}
]
}