{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "tf2onnx.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NLURMaC6SNC8",
"outputId": "5fde0e77-da14-4be1-fb60-2a191e9f8d96"
},
"source": [
"!pip install -U tf2onnx onnxruntime"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting tf2onnx\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/42/c8/b082ab508ca201ba830de2f4f3aa42fc862008236825ef69bd10705c3cf2/tf2onnx-1.8.2-py3-none-any.whl (326kB)\n",
"\u001b[K |████████████████████████████████| 327kB 8.7MB/s \n",
"\u001b[?25hCollecting onnxruntime\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b3/a9/f009251fd1b91a2e1ce6f22d4b5be9936fbd0072842c5087a2a49706c509/onnxruntime-1.6.0-cp36-cp36m-manylinux2014_x86_64.whl (4.1MB)\n",
"\u001b[K |████████████████████████████████| 4.1MB 16.3MB/s \n",
"\u001b[?25hRequirement already satisfied, skipping upgrade: flatbuffers in /usr/local/lib/python3.6/dist-packages (from tf2onnx) (1.12)\n",
"Collecting onnx>=1.4.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/f1/db/608877fea324c3a44aaa50dbcb23ff5b7e3d222a7c5511c19d1651db512e/onnx-1.8.1-cp36-cp36m-manylinux2010_x86_64.whl (14.5MB)\n",
"\u001b[K |████████████████████████████████| 14.5MB 269kB/s \n",
"\u001b[?25hRequirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from tf2onnx) (1.15.0)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.14.1 in /usr/local/lib/python3.6/dist-packages (from tf2onnx) (1.19.5)\n",
"Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from tf2onnx) (2.23.0)\n",
"Requirement already satisfied, skipping upgrade: protobuf in /usr/local/lib/python3.6/dist-packages (from onnxruntime) (3.12.4)\n",
"Requirement already satisfied, skipping upgrade: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.6/dist-packages (from onnx>=1.4.1->tf2onnx) (3.7.4.3)\n",
"Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->tf2onnx) (2020.12.5)\n",
"Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->tf2onnx) (3.0.4)\n",
"Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->tf2onnx) (2.10)\n",
"Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->tf2onnx) (1.24.3)\n",
"Requirement already satisfied, skipping upgrade: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf->onnxruntime) (53.0.0)\n",
"Installing collected packages: onnx, tf2onnx, onnxruntime\n",
"Successfully installed onnx-1.8.1 onnxruntime-1.6.0 tf2onnx-1.8.2\n"
],
"name": "stdout"
}
]
},
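{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check: both packages expose a `__version__` attribute, so a quick print confirms which converter and runtime versions the rest of the notebook runs against."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Confirm the installed converter and runtime versions\n",
"import tf2onnx\n",
"import onnxruntime\n",
"print('tf2onnx:', tf2onnx.__version__)\n",
"print('onnxruntime:', onnxruntime.__version__)"
],
"execution_count": null,
"outputs": []
},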
{
"cell_type": "markdown",
"metadata": {
"id": "XvW40LpQPiJk"
},
"source": [
"# Step 1: Build the model using TensorFlow"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d3hnBhfUMcD8",
"outputId": "4d268d93-96d0-4119-9e6e-74b12df02ecb"
},
"source": [
"import time\r\n",
"import numpy as np\r\n",
"np.set_printoptions(suppress=True)\r\n",
"import tensorflow as tf\r\n",
"\r\n",
"mnist = tf.keras.datasets.mnist\r\n",
"\r\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\r\n",
"\r\n",
"x_train, x_test = x_train / 255.0, x_test / 255.0\r\n",
"\r\n",
"model = tf.keras.models.Sequential([\r\n",
" tf.keras.layers.Flatten(input_shape=(28, 28)),\r\n",
" tf.keras.layers.Dense(128, activation='relu'),\r\n",
" tf.keras.layers.Dropout(0.2),\r\n",
" tf.keras.layers.Dense(10)\r\n",
"])\r\n",
"\r\n",
"model.compile(optimizer='adam',\r\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\r\n",
" metrics=['accuracy'])\r\n",
"\r\n",
"model.fit(x_train, y_train, epochs=5)\r\n",
"\r\n",
"model.evaluate(x_test, y_test, verbose=2)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
"11493376/11490434 [==============================] - 0s 0us/step\n",
"Epoch 1/5\n",
"1875/1875 [==============================] - 4s 2ms/step - loss: 0.4778 - accuracy: 0.8604\n",
"Epoch 2/5\n",
"1875/1875 [==============================] - 4s 2ms/step - loss: 0.1480 - accuracy: 0.9568\n",
"Epoch 3/5\n",
"1875/1875 [==============================] - 4s 2ms/step - loss: 0.1084 - accuracy: 0.9672\n",
"Epoch 4/5\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.0894 - accuracy: 0.9727\n",
"Epoch 5/5\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.0741 - accuracy: 0.9775\n",
"313/313 - 0s - loss: 0.0733 - accuracy: 0.9757\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[0.07328364253044128, 0.9757000207901001]"
]
},
"metadata": {
"tags": []
},
"execution_count": 2
}
]
},
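{
"cell_type": "markdown",
"metadata": {},
"source": [
"It can help to see the layer stack the converter will later have to map to ONNX operators; `model.summary()` prints it."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Print the layer stack that will be exported and converted below\n",
"model.summary()"
],
"execution_count": null,
"outputs": []
},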
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zEqLftGPM734",
"outputId": "e3033354-1112-4c19-eb6f-00fda740d0c4"
},
"source": [
"probability_model = tf.keras.Sequential([model, \r\n",
" tf.keras.layers.Softmax()])\r\n",
"\r\n",
"start_time = time.time()\r\n",
"tf_predictions = probability_model.predict(x_test)\r\n",
"print(\"Time taken by TensorFlow model: \", time.time() - start_time)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"Time taken by TensorFlow model: 0.3348524570465088\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QNqFONhpSwAG",
"outputId": "621c2bcc-1633-4f8a-9cd3-e38ede433d3a"
},
"source": [
"tf.saved_model.save(model,'/content')"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /content/assets\n"
],
"name": "stdout"
}
]
},
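{
"cell_type": "markdown",
"metadata": {},
"source": [
"The converter reads the `serving_default` signature from the SavedModel written above. One way to inspect that signature first is TensorFlow's `saved_model_cli` tool; the call below assumes the SavedModel directory is `/content`, as in the previous cell."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"!saved_model_cli show --dir /content --tag_set serve --signature_def serving_default"
],
"execution_count": null,
"outputs": []
},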
{
"cell_type": "markdown",
"metadata": {
"id": "Lrw4y1PSQB0W"
},
"source": [
"# Step 2: Convert the model from TensorFlow to ONNX format"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pSkWynD1TuoY",
"outputId": "d74c7e0c-2b91-4157-b702-f12ea368455c"
},
"source": [
"!python -m tf2onnx.convert --saved-model '/content' --opset 10 --output model.onnx"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"2021-02-03 16:09:51.040012: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1\n",
"2021-02-03 16:09:52,803 - WARNING - '--tag' not specified for saved_model. Using --tag serve\n",
"2021-02-03 16:09:52.816190: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
"2021-02-03 16:09:52.817560: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1\n",
"2021-02-03 16:09:52.827665: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
"2021-02-03 16:09:52.827704: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (7ceb103dd6ee): /proc/driver/nvidia/version does not exist\n",
"2021-02-03 16:09:52.827954: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX512F\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2021-02-03 16:09:52.828120: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
"2021-02-03 16:09:53,088 - INFO - Signatures found in model: [serving_default].\n",
"2021-02-03 16:09:53,088 - WARNING - '--signature_def' not specified, using first signature: serving_default\n",
"2021-02-03 16:09:53.089956: I tensorflow/core/grappler/devices.cc:69] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0\n",
"2021-02-03 16:09:53.090164: I tensorflow/core/grappler/clusters/single_machine.cc:356] Starting new session\n",
"2021-02-03 16:09:53.090479: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
"2021-02-03 16:09:53.090731: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2000185000 Hz\n",
"2021-02-03 16:09:53.092599: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:928] Optimization results for grappler item: graph_to_optimize\n",
" function_optimizer: Graph size after: 26 nodes (19), 33 edges (26), time = 0.693ms.\n",
" function_optimizer: function_optimizer did nothing. time = 0.017ms.\n",
"\n",
"2021-02-03 16:09:53.125417: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tf2onnx/tf_loader.py:529: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
"2021-02-03 16:09:53,130 - WARNING - From /usr/local/lib/python3.6/dist-packages/tf2onnx/tf_loader.py:529: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
"2021-02-03 16:09:53.131296: I tensorflow/core/grappler/devices.cc:69] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0\n",
"2021-02-03 16:09:53.131456: I tensorflow/core/grappler/clusters/single_machine.cc:356] Starting new session\n",
"2021-02-03 16:09:53.131740: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
"2021-02-03 16:09:53.141451: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:928] Optimization results for grappler item: graph_to_optimize\n",
" constant_folding: Graph size after: 18 nodes (-8), 25 edges (-8), time = 7.479ms.\n",
" function_optimizer: function_optimizer did nothing. time = 0.018ms.\n",
" constant_folding: Graph size after: 18 nodes (0), 25 edges (0), time = 0.65ms.\n",
" function_optimizer: function_optimizer did nothing. time = 0.014ms.\n",
"\n",
"2021-02-03 16:09:53.149204: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
"2021-02-03 16:09:53,149 - INFO - Using tensorflow=2.4.1, onnx=1.8.1, tf2onnx=1.8.2/05babf\n",
"2021-02-03 16:09:53,149 - INFO - Using opset <onnx, 10>\n",
"2021-02-03 16:09:53,162 - INFO - Computed 0 values for constant folding\n",
"2021-02-03 16:09:53,200 - INFO - Optimizing ONNX model\n",
"2021-02-03 16:09:53,219 - INFO - After optimization: Cast -1 (1->0), Identity -6 (6->0)\n",
"2021-02-03 16:09:53,220 - INFO - \n",
"2021-02-03 16:09:53,220 - INFO - Successfully converted TensorFlow model /content to ONNX\n",
"2021-02-03 16:09:53,220 - INFO - ONNX model is saved at model.onnx\n"
],
"name": "stdout"
}
]
},
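{
"cell_type": "markdown",
"metadata": {},
"source": [
"The CLI call above is what this notebook relies on. tf2onnx also offers a Python API; the sketch below uses `tf2onnx.convert.from_keras`, which may require a newer tf2onnx release than the one installed here, so treat it as an optional alternative. The input name `input` and the output file `model_api.onnx` are illustrative choices, not values used elsewhere in this notebook."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optional alternative: convert the in-memory Keras model directly.\n",
"# Note: tf2onnx.convert.from_keras may not exist in older tf2onnx releases.\n",
"import tf2onnx.convert\n",
"\n",
"spec = (tf.TensorSpec((None, 28, 28), tf.float32, name='input'),)  # 'input' is an arbitrary name\n",
"model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec,\n",
"                                            opset=10, output_path='model_api.onnx')"
],
"execution_count": null,
"outputs": []
},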
{
"cell_type": "markdown",
"metadata": {
"id": "F9MEJ1DZQIzf"
},
"source": [
"## Load the model from the disk"
]
},
{
"cell_type": "code",
"metadata": {
"id": "G6n5paWxTqW0"
},
"source": [
"import onnx\r\n",
"onnx_model = onnx.load(\"model.onnx\")"
],
"execution_count": 6,
"outputs": []
},
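{
"cell_type": "markdown",
"metadata": {},
"source": [
"`onnx.load` only parses the protobuf from disk; a structural validation with onnx's checker, plus a human-readable dump of the graph, is a cheap way to confirm the conversion produced something sensible."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Validate the graph structure; this raises an exception if the model is malformed\n",
"onnx.checker.check_model(onnx_model)\n",
"\n",
"# Human-readable summary of the graph (inputs, nodes, outputs)\n",
"print(onnx.helper.printable_graph(onnx_model.graph))"
],
"execution_count": null,
"outputs": []
},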
{
"cell_type": "markdown",
"metadata": {
"id": "pMdAmQTuQL3j"
},
"source": [
"# Step 3: Inference/Prediction using ONNX format"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LfQM8HCqVT4E",
"outputId": "b2ccc227-83d9-4047-cc40-a53f8a5ab880"
},
"source": [
"import onnxruntime as rt\r\n",
"import numpy as np\r\n",
"\r\n",
"model = ('/content/model.onnx')\r\n",
"start_time = time.time()\r\n",
"session = rt.InferenceSession(model)\r\n",
"input_name = session.get_inputs()[0].name\r\n",
"label_name = session.get_outputs()[0].name\r\n",
"onnx_predictions = session.run([label_name], {input_name: x_test.astype(np.float32)})[0]\r\n",
"print(\"Time taken by ONNX: \", time.time() - start_time)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Time taken by ONNX: 0.03946185111999512\n"
],
"name": "stdout"
}
]
},
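{
"cell_type": "markdown",
"metadata": {},
"source": [
"The session's input and output metadata explain why `x_test` is cast to `float32` above; the cell below simply prints the names, shapes, and element types that ONNX Runtime reports for this model."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Inspect the model's expected inputs/outputs as reported by ONNX Runtime\n",
"for i in session.get_inputs():\n",
"    print('input :', i.name, i.shape, i.type)\n",
"for o in session.get_outputs():\n",
"    print('output:', o.name, o.shape, o.type)"
],
"execution_count": null,
"outputs": []
},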
{
"cell_type": "markdown",
"metadata": {
"id": "j0yVA9pgQOah"
},
"source": [
"# Compare TensorFlow and ONNX results"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GTVXbWA9WMYe",
"outputId": "605dfd34-e761-448a-e8c5-3259e885277b"
},
"source": [
"tf_predictions[0]"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.00000015, 0. , 0.00000046, 0.0000395 , 0. ,\n",
" 0.00000009, 0. , 0.999959 , 0.00000029, 0.00000058],\n",
" dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LWtu2nKuBFFC",
"outputId": "a66186f9-efc0-4796-fb61-c101180a0dfb"
},
"source": [
"tf.nn.softmax(onnx_predictions[0])"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: shape=(10,), dtype=float32, numpy=\n",
"array([0.00000015, 0. , 0.00000046, 0.0000395 , 0. ,\n",
" 0.00000009, 0. , 0.999959 , 0.00000029, 0.00000058],\n",
" dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
}
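,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Eyeballing one row is a start; the cell below checks the full test set numerically. The ONNX model returns logits (the softmax layer was only added to `probability_model`), so the logits are passed through a softmax before comparing, and a small absolute tolerance of `1e-5` is assumed."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Compare TensorFlow and ONNX predictions over the whole test set\n",
"onnx_probs = tf.nn.softmax(onnx_predictions, axis=1).numpy()\n",
"\n",
"print('Max absolute difference:', np.max(np.abs(tf_predictions - onnx_probs)))\n",
"print('All close (atol=1e-5):  ', np.allclose(tf_predictions, onnx_probs, atol=1e-5))\n",
"print('Predicted classes agree:', np.array_equal(tf_predictions.argmax(axis=1), onnx_probs.argmax(axis=1)))"
],
"execution_count": null,
"outputs": []
}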
]
}