Created
August 2, 2023 17:34
-
-
Save danibene/b2d5ec03a6c259246e590cbf1dda4e93 to your computer and use it in GitHub Desktop.
skl2onnx_pr_1004.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/danibene/b2d5ec03a6c259246e590cbf1dda4e93/skl2onnx_pr_1004.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install git+https://github.com/onnx/sklearn-onnx.git" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "pIJOU3Jah4-A", | |
"outputId": "5f9424c0-035d-45ef-8515-9be5babc5fd0" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Collecting git+https://github.com/onnx/sklearn-onnx.git\n", | |
" Cloning https://github.com/onnx/sklearn-onnx.git to /tmp/pip-req-build-ipyc08zc\n", | |
" Running command git clone --filter=blob:none --quiet https://github.com/onnx/sklearn-onnx.git /tmp/pip-req-build-ipyc08zc\n", | |
" Resolved https://github.com/onnx/sklearn-onnx.git to commit 3ef5e136ca2c02b7aae03d6879265a5691c661e2\n", | |
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", | |
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", | |
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", | |
"Requirement already satisfied: onnx>=1.2.1 in /usr/local/lib/python3.10/dist-packages (from skl2onnx==1.14.1) (1.14.0)\n", | |
"Requirement already satisfied: scikit-learn>=0.19 in /usr/local/lib/python3.10/dist-packages (from skl2onnx==1.14.1) (1.2.2)\n", | |
"Requirement already satisfied: onnxconverter-common>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from skl2onnx==1.14.1) (1.13.0)\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx==1.14.1) (1.22.4)\n", | |
"Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx==1.14.1) (3.20.3)\n", | |
"Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.2.1->skl2onnx==1.14.1) (4.7.1)\n", | |
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxconverter-common>=1.7.0->skl2onnx==1.14.1) (23.1)\n", | |
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.19->skl2onnx==1.14.1) (1.10.1)\n", | |
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.19->skl2onnx==1.14.1) (1.3.1)\n", | |
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.19->skl2onnx==1.14.1) (3.2.0)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from pathlib import Path\n", | |
"\n", | |
"import numpy as np\n", | |
"import onnx\n", | |
"from skl2onnx import convert_sklearn\n", | |
"from skl2onnx.common.data_types import FloatTensorType\n", | |
"from sklearn.datasets import load_iris\n", | |
"from sklearn.ensemble import RandomForestClassifier" | |
], | |
"metadata": { | |
"id": "yNHVo62uh0Fj" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"RANDOM_SEED = 42\n", | |
"TARGET_OPSET = 12\n", | |
"EXPORT_PARENT_PATH = Path(\"exported_to_onnx\")\n", | |
"\n", | |
"\n", | |
"def export_sample_onnx(iteration=1):\n", | |
" # Load the dataset\n", | |
" iris = load_iris()\n", | |
" X = iris.data\n", | |
" y = iris.target\n", | |
" clf = RandomForestClassifier(random_state=RANDOM_SEED)\n", | |
" clf.fit(X, y)\n", | |
" initial_type = [(\"float_input\", FloatTensorType([1, len(X)]))]\n", | |
" converted = convert_sklearn(clf, initial_types=initial_type, target_opset=TARGET_OPSET)\n", | |
" export_path = Path(EXPORT_PARENT_PATH, \"model_\" + str(iteration) + \".onnx\")\n", | |
" Path(export_path).parent.mkdir(parents=True, exist_ok=True)\n", | |
" with open(export_path, \"wb\") as f:\n", | |
" f.write(converted.SerializeToString())\n", | |
"\n", | |
"\n", | |
"def compare_onnx_graphs(model_1_onnx_path: Path, model_2_onnx_path: Path):\n", | |
" # Load the ONNX models\n", | |
" model_1 = onnx.load(str(model_1_onnx_path))\n", | |
" model_2 = onnx.load(str(model_2_onnx_path))\n", | |
"\n", | |
" # Get the graphs from the models\n", | |
" graph_1 = model_1.graph\n", | |
" graph_2 = model_2.graph\n", | |
"\n", | |
" diff_graphs = []\n", | |
"\n", | |
" # Compare the number of nodes\n", | |
" if len(graph_1.node) != len(graph_2.node):\n", | |
" print(\"The number of nodes in the graphs is different.\")\n", | |
" return [{\"graph_1\": node, \"graph_2\": \"not_same_len\"} for node in graph_1.node]\n", | |
"\n", | |
" # Compare each node in the graphs\n", | |
" for node_1, node_2 in zip(graph_1.node, graph_2.node):\n", | |
" if not _check_same_onnx_nodes(node_1, node_2):\n", | |
" diff_graphs.append({\"graph_1\": node_1, \"graph_2\": node_2})\n", | |
"\n", | |
" return diff_graphs\n", | |
"\n", | |
"\n", | |
"def _check_same_onnx_nodes(node_1: onnx.NodeProto, node_2: onnx.NodeProto):\n", | |
" if node_1.attribute != node_2.attribute:\n", | |
" return False\n", | |
" elif node_1.input != node_2.input:\n", | |
" return False\n", | |
" elif node_1.output != node_2.output:\n", | |
" return False\n", | |
" elif node_1.op_type != node_2.op_type:\n", | |
" return False\n", | |
" else:\n", | |
" return True" | |
], | |
"metadata": { | |
"id": "JVRhoZegiGhv" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"id": "umNNhhKHhqTa" | |
}, | |
"outputs": [], | |
"source": [ | |
"for i in range(5):\n", | |
" export_sample_onnx(i)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"onnx_file_paths = list(EXPORT_PARENT_PATH.rglob(\"*.onnx\"))\n", | |
"\n", | |
"for i in range(len(onnx_file_paths) - 1):\n", | |
" diffs_graphs = compare_onnx_graphs(onnx_file_paths[i], onnx_file_paths[i + 1])\n", | |
" if len(diffs_graphs) > 0:\n", | |
" print(diffs_graphs)" | |
], | |
"metadata": { | |
"id": "FKSCk1zWiBSF", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "bd6a8fa4-45c7-4199-ab51-0f41f5b19294" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"[{'graph_1': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
", 'graph_2': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
"}, {'graph_1': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
", 'graph_2': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
"}]\n", | |
"[{'graph_1': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
", 'graph_2': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
"}, {'graph_1': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
", 'graph_2': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
"}]\n", | |
"[{'graph_1': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
", 'graph_2': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
"}, {'graph_1': input: \"probabilities\"\n", | |
"output: \"output_probability\"\n", | |
"name: \"ZipMap\"\n", | |
"op_type: \"ZipMap\"\n", | |
"attribute {\n", | |
" name: \"classlabels_int64s\"\n", | |
" ints: 0\n", | |
" ints: 1\n", | |
" ints: 2\n", | |
" type: INTS\n", | |
"}\n", | |
"domain: \"ai.onnx.ml\"\n", | |
", 'graph_2': input: \"label\"\n", | |
"output: \"output_label\"\n", | |
"name: \"Cast\"\n", | |
"op_type: \"Cast\"\n", | |
"attribute {\n", | |
" name: \"to\"\n", | |
" i: 7\n", | |
" type: INT\n", | |
"}\n", | |
"domain: \"\"\n", | |
"}]\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "hTwMnDLXkp7o" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment