MakeFullUseOfEdgeTPU.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "MakeFullUseOfEdgeTPU.ipynb", | |
"provenance": [], | |
"collapsed_sections": [ | |
"XkL8FXvCl9BV" | |
], | |
"authorship_tag": "ABX9TyPJNGhxARgtJuhqBZQAxPdH", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/iwatake2222/5bb26123ebf891957109787d4f4dc685/makefulluseofedgetpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "lAMuurEvr6Af", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Create very high TOPS tflite model for Edge TPU" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "-2hPvu7ar1SJ", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Setup" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XEmYPbuAy1Wc", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Install EdgeTPU compiler\n", | |
"!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -\n", | |
"!echo \"deb https://packages.cloud.google.com/apt coral-edgetpu-stable main\" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list\n", | |
"!sudo apt-get update\n", | |
"!sudo apt-get install edgetpu-compiler" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
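{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm the compiler installed correctly, its version can be printed (a minimal check; this assumes the compiler accepts a `--version` flag, and the same version string also appears in the compile output further below):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Confirm the Edge TPU compiler is on PATH (assumes the --version flag)\n",
"!edgetpu_compiler --version"
],
"execution_count": 0,
"outputs": []
},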
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "rfHlM9M2r4Xs", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Create Keras model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Acemn-Iqt0tR", | |
"colab_type": "code", | |
"outputId": "77bd34bb-1df1-436f-8756-f311cb75bbbb", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
} | |
}, | |
"source": [ | |
"%tensorflow_version 1.x\n", | |
"from __future__ import absolute_import, division, print_function, unicode_literals\n", | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"\n", | |
"print(tf.__version__)\n", | |
"\n", | |
"# Model parameters\n", | |
"INTERNAL_MODEL_WIDTH = 128\n", | |
"INTERNAL_MODEL_HEIGHT = 128\n", | |
"FILTER_CHANNEL = 64\n", | |
"FILTER_KERNEL_SIZE = 3\n", | |
"FILTER_LAYERS_NUM = 100\n", | |
"\n", | |
"# Create model\n", | |
"input0 = tf.keras.layers.Input(shape=(2,2,3))\n", | |
"pad0 = tf.keras.layers.ZeroPadding2D(padding=((int)(INTERNAL_MODEL_WIDTH/2)-1, (int)(INTERNAL_MODEL_HEIGHT/2)-1))(input0)\n", | |
"conv = pad0\n", | |
"for i in range(FILTER_LAYERS_NUM):\n", | |
" conv = tf.keras.layers.Conv2D(\n", | |
" filters=FILTER_CHANNEL,\n", | |
" kernel_size=(FILTER_KERNEL_SIZE,FILTER_KERNEL_SIZE),\n", | |
" strides=(1,1),\n", | |
" padding=\"same\",\n", | |
" activation=tf.keras.activations.relu\n", | |
" )(conv)\n", | |
"convOut = tf.keras.layers.Conv2D(\n", | |
" filters=1,\n", | |
" kernel_size=(1,1),\n", | |
" strides=((int)(INTERNAL_MODEL_WIDTH/2), (int)(INTERNAL_MODEL_HEIGHT/2)),\n", | |
" padding=\"valid\",\n", | |
" activation=tf.keras.activations.relu\n", | |
")(conv)\n", | |
"model = tf.keras.models.Model(inputs=[input0], outputs=[convOut])\n", | |
"\n", | |
"# Save model\n", | |
"model.summary()\n", | |
"model.compile(\n", | |
" optimizer=tf.keras.optimizers.Adam(),\n", | |
" loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n", | |
" metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]\n", | |
")\n", | |
"model_name = \"model_conv_\" + str(INTERNAL_MODEL_WIDTH) + \"x\" + str(INTERNAL_MODEL_HEIGHT) + \"x\" + str(FILTER_CHANNEL) + \"x\" + str(FILTER_KERNEL_SIZE) + \"x\" + str(FILTER_LAYERS_NUM)\n", | |
"model.save(model_name + \".h5\")\n", | |
"\n", | |
"print(\"FLOPs = \" + (str)(INTERNAL_MODEL_WIDTH * INTERNAL_MODEL_HEIGHT * FILTER_KERNEL_SIZE * FILTER_KERNEL_SIZE * FILTER_CHANNEL * FILTER_CHANNEL * FILTER_LAYERS_NUM * 2))" | |
], | |
"execution_count": 91, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1.15.2\n", | |
"Model: \"model\"\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_1 (InputLayer) [(None, 2, 2, 3)] 0 \n", | |
"_________________________________________________________________\n", | |
"zero_padding2d (ZeroPadding2 (None, 128, 128, 3) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d (Conv2D) (None, 128, 128, 64) 1792 \n", | |
"_________________________________________________________________\n", | |
"conv2d_1 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_2 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_3 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_4 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_5 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_6 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_7 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_8 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_9 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_10 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_11 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_12 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_13 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_14 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_15 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_16 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_17 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_18 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_19 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_20 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_21 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_22 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_23 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_24 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_25 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_26 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_27 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_28 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_29 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_30 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_31 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_32 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_33 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_34 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_35 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_36 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_37 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_38 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_39 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_40 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_41 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_42 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_43 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_44 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_45 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_46 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_47 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_48 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_49 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_50 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_51 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_52 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_53 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_54 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_55 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_56 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_57 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_58 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_59 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_60 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_61 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_62 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_63 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_64 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_65 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_66 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_67 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_68 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_69 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_70 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_71 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_72 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_73 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_74 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_75 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_76 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_77 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_78 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_79 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_80 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_81 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_82 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_83 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_84 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_85 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_86 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_87 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_88 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_89 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_90 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_91 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_92 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_93 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_94 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_95 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_96 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_97 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_98 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_99 (Conv2D) (None, 128, 128, 64) 36928 \n", | |
"_________________________________________________________________\n", | |
"conv2d_100 (Conv2D) (None, 2, 2, 1) 65 \n", | |
"=================================================================\n", | |
"Total params: 3,657,729\n", | |
"Trainable params: 3,657,729\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n", | |
"FLOPs = 120795955200\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
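{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the `FLOPs` value printed above (a rough per-layer count using only the shapes from the model summary): the simple formula treats all 100 convolutions as 64-to-64-channel 3x3 layers, but the first layer has only 3 input channels and the final 1x1 convolution is tiny, so the refined count comes out slightly lower, in line with the FlopCo result later in this notebook."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Per-layer estimate (sketch): MACs = H_out * W_out * K * K * C_in * C_out, FLOPs ~= 2 * MACs\n",
"H = W = INTERNAL_MODEL_WIDTH                                   # 128\n",
"K = FILTER_KERNEL_SIZE                                         # 3\n",
"C = FILTER_CHANNEL                                             # 64\n",
"macs_first  = H * W * K * K * 3 * C                            # conv2d: 3 -> 64 channels\n",
"macs_middle = H * W * K * K * C * C * (FILTER_LAYERS_NUM - 1)  # conv2d_1 .. conv2d_99: 64 -> 64\n",
"macs_last   = 2 * 2 * 1 * 1 * C * 1                            # final 1x1 conv on the 2x2 output\n",
"total_macs  = macs_first + macs_middle + macs_last\n",
"print(\"MACs  = \" + str(total_macs))      # 59822309632, matching the FlopCo MAC count below\n",
"print(\"FLOPs = \" + str(2 * total_macs))  # ~1.196e11, slightly below the rough estimate above"
],
"execution_count": 0,
"outputs": []
},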
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "SjOJGnCksQFR", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Convert to tflite model (full integer quantization)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "uaCM-iozuaYI", | |
"colab_type": "code", | |
"outputId": "f70e633f-1c82-45f3-cb4d-c2162219dc53", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 69 | |
} | |
}, | |
"source": [ | |
"%tensorflow_version 1.x\n", | |
"from __future__ import absolute_import, division, print_function, unicode_literals\n", | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"\n", | |
"## Prepara dataset generator for calibration\n", | |
"dummy_data_num = 2\n", | |
"dummy_data = []\n", | |
"for i in range(dummy_data_num):\n", | |
" dummy_data.append(np.random.rand(1, model.input_shape[1], model.input_shape[2], model.input_shape[3]))\n", | |
" dummy_data[i] = dummy_data[i].astype(np.float32)\n", | |
"\n", | |
"def representative_dataset_gen():\n", | |
" for i in range(dummy_data_num):\n", | |
" yield [dummy_data[i]]\n", | |
"\n", | |
"# Convert\n", | |
"# loaded_model = tf.keras.models.load_model(model_name + \".h5\")\n", | |
"# converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)\n", | |
"converter = tf.lite.TFLiteConverter.from_keras_model_file(model_name + \".h5\")\n", | |
"\n", | |
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", | |
"converter.representative_dataset = representative_dataset_gen\n", | |
"\n", | |
"# For full integer quantization\n", | |
"converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n", | |
"converter.inference_input_type = tf.uint8\n", | |
"converter.inference_output_type = tf.uint8\n", | |
"converter.experimental_new_converter = True # will be no need in the future\n", | |
"\n", | |
"tflite_model = converter.convert()\n", | |
"open(model_name + \".tflite\", \"wb\").write(tflite_model)" | |
], | |
"execution_count": 92, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Froze 202 variables.\n", | |
"INFO:tensorflow:Converted 202 variables to const ops.\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"3761240" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 92 | |
} | |
] | |
}, | |
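{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check before compiling (a minimal sketch, assuming the conversion above succeeded): load the quantized file with `tf.lite.Interpreter`, confirm the input and output tensors are `uint8`, and run one dummy inference on the CPU."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# CPU-side check of the fully quantized model (sketch; not required for Edge TPU compilation)\n",
"interpreter = tf.lite.Interpreter(model_path=model_name + \".tflite\")\n",
"interpreter.allocate_tensors()\n",
"input_details = interpreter.get_input_details()\n",
"output_details = interpreter.get_output_details()\n",
"print(\"input :\", input_details[0][\"shape\"], input_details[0][\"dtype\"])    # expect (1, 2, 2, 3) uint8\n",
"print(\"output:\", output_details[0][\"shape\"], output_details[0][\"dtype\"])  # expect (1, 2, 2, 1) uint8\n",
"\n",
"dummy_input = np.random.randint(0, 256, size=tuple(input_details[0][\"shape\"]), dtype=np.uint8)\n",
"interpreter.set_tensor(input_details[0][\"index\"], dummy_input)\n",
"interpreter.invoke()\n",
"print(\"output value:\", interpreter.get_tensor(output_details[0][\"index\"]))"
],
"execution_count": 0,
"outputs": []
},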
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "SjQHI2x7sUjz", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Convert to edgetpu model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3meAgdciy8tE", | |
"colab_type": "code", | |
"outputId": "87b3067e-7723-4f6b-e374-41febc4560dd", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 278 | |
} | |
}, | |
"source": [ | |
"# Convert into edgetpu model\n", | |
"tflite_model_name = model_name + \".tflite\"\n", | |
"!edgetpu_compiler $tflite_model_name\n", | |
"\n", | |
"# Download to local\n", | |
"from google.colab import files\n", | |
"# files.download(model_name + \".tflite\")\n", | |
"files.download(model_name + \"_edgetpu.tflite\")" | |
], | |
"execution_count": 93, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Edge TPU Compiler version 2.1.302470888\n", | |
"\n", | |
"Model compiled successfully in 667 ms.\n", | |
"\n", | |
"Input model: model_conv_128x128x64x3x100.tflite\n", | |
"Input size: 3.59MiB\n", | |
"Output model: model_conv_128x128x64x3x100_edgetpu.tflite\n", | |
"Output size: 3.77MiB\n", | |
"On-chip memory used for caching model parameters: 3.56MiB\n", | |
"On-chip memory remaining for caching model parameters: 3.35MiB\n", | |
"Off-chip memory used for streaming uncached model parameters: 0.00B\n", | |
"Number of Edge TPU subgraphs: 1\n", | |
"Total number of operations: 104\n", | |
"Operation log: model_conv_128x128x64x3x100_edgetpu.log\n", | |
"See the operation log file for individual operation details.\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
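{
"cell_type": "markdown",
"metadata": {},
"source": [
"The compiler output above shows all parameters cached in on-chip memory and a single Edge TPU subgraph. To confirm that every operation was mapped to the Edge TPU rather than falling back to the CPU, the operation log named in that output can be printed (a small sketch reusing the same file name):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Print the per-operation mapping report written by the Edge TPU compiler\n",
"log_name = model_name + \"_edgetpu.log\"\n",
"!cat $log_name"
],
"execution_count": 0,
"outputs": []
},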
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "n6gugXkKl4wZ", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Calculate FLOPS" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Or4v0viVgxmn", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"!pip install flopco-keras " | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "95CYxLlwg0_9", | |
"colab_type": "code", | |
"outputId": "7e339226-1769-4c53-fde4-fa09c52783e8", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 89 | |
} | |
}, | |
"source": [ | |
"from flopco_keras import FlopCoKeras\n", | |
"stats = FlopCoKeras(model)\n", | |
"print(f\"FLOPs: {stats.total_flops / 1000 / 1000 / 1000 } [GFLOPs]\")\n", | |
"print(f\"MACs: {stats.total_macs}\")\n", | |
"print(f\"Relative FLOPs: {stats.relative_flops}\")" | |
], | |
"execution_count": 94, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"FLOPs: 119.749476868 [GFLOPs]\n", | |
"MACs: 59822309632\n", | |
"Relative FLOPs: [0.0, 0.0, 0.0004816027719567541, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 0.01009614538302068, 4.308995859487448e-09]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "aWEKQ27wafKT", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Code for inference (running on local PC with Edge TPU)\n", | |
"```\n", | |
"sudo pip3 install https://dl.google.com/coral/python/tflite_runtime-2.1.0.post1-cp37-cp37m-linux_armv7l.whl\n", | |
"echo \"deb https://packages.cloud.google.com/apt coral-edgetpu-stable main\" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list\n", | |
"curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -\n", | |
"sudo apt-get update\n", | |
"sudo apt-get install libedgetpu1-max\n", | |
"\n", | |
"python3 run.py\n", | |
"```\n", | |
"\n", | |
"Result\n", | |
"\n", | |
"```\n", | |
"pi@raspberrypi:~/top $ sudo python3 run.py\n", | |
"Inference time = 0.031113 [sec]\n", | |
"3848.906 [GFLOPS]\n", | |
"```\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VWwDk9Hyaxky", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import platform\n", | |
"import time\n", | |
"import tflite_runtime.interpreter as tflite\n", | |
"\n", | |
"EDGETPU_SHARED_LIB = {\n", | |
"\t'Linux': 'libedgetpu.so.1',\n", | |
"\t'Darwin': 'libedgetpu.1.dylib',\n", | |
"\t'Windows': 'edgetpu.dll'\n", | |
"}[platform.system()]\n", | |
"\n", | |
"def make_interpreter(model_file):\n", | |
"\tmodel_file, *device = model_file.split('@')\n", | |
"\treturn tflite.Interpreter(\n", | |
"\t\tmodel_path=model_file,\n", | |
"\t\texperimental_delegates=[\n", | |
"\t\t\ttflite.load_delegate(EDGETPU_SHARED_LIB,\n", | |
"\t\t\t\t\t\t\t\t{'device': device[0]} if device else {})\n", | |
"\t\t])\n", | |
"\n", | |
"def main():\n", | |
"\tinterpreter = make_interpreter(\"model_conv_128x128x64x3x100_edgetpu.tflite\")\n", | |
"\tinterpreter.allocate_tensors()\n", | |
"\tinterpreter.invoke()\n", | |
"\n", | |
"\tnum_measurement = 100\n", | |
"\tstart = time.perf_counter()\n", | |
"\tfor _ in range(num_measurement):\n", | |
"\t\tinterpreter.invoke()\n", | |
"\tinference_time = (time.perf_counter() - start) / num_measurement\n", | |
"\tprint(\"Inference time = %.6f [sec]\" % inference_time)\n", | |
"\tprint(\"%.3f [GFLOPS]\" % (119.749476868 / inference_time))\n", | |
"\n", | |
"if __name__ == \"__main__\":\n", | |
"\tmain()\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
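{
"cell_type": "markdown",
"metadata": {},
"source": [
"The GFLOPS figure printed by this script is simply the model's estimated FLOPs divided by the measured latency: 119.75 GFLOP / 0.031113 s is roughly 3849 GFLOPS, close to the Edge TPU's nominal 4 TOPS (int8) rating."
]
},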
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "MSLD14DBeune", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"================" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XkL8FXvCl9BV", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Calculate FLOPS (not working)\n", | |
"https://github.com/tensorflow/tensorflow/issues/32809" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "x6ra9jyl-zO3", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Calculate FLOPS\n", | |
"# https://stackoverflow.com/questions/49525776/how-to-calculate-a-mobilenet-flops-in-keras\n", | |
"# https://github.com/tensorflow/tensorflow/issues/32809\n", | |
"%tensorflow_version 2.x\n", | |
"def get_flops(model_h5_path):\n", | |
" session = tf.compat.v1.Session()\n", | |
" graph = tf.compat.v1.get_default_graph()\n", | |
" with graph.as_default():\n", | |
" with session.as_default():\n", | |
" # model = tf.keras.models.load_model(model_h5_path)\n", | |
" model = tf.keras.applications.mobilenet.MobileNet()\n", | |
" run_meta = tf.RunMetadata()\n", | |
" opts = tf.profiler.ProfileOptionBuilder.float_operation()\n", | |
"\n", | |
" # Optional: save printed results to file\n", | |
" # flops_log_path = os.path.join(tempfile.gettempdir(), 'tf_flops_log.txt')\n", | |
" # opts['output'] = 'file:outfile={}'.format(flops_log_path)\n", | |
"\n", | |
" # We use the Keras session graph in the call to the profiler.\n", | |
" flops = tf.profiler.profile(graph=graph,\n", | |
" run_meta=run_meta, cmd='op', options=opts)\n", | |
"\n", | |
" return flops.total_float_ops\n", | |
"\n", | |
"print((str)(get_flops(model_name + \".h5\") / 1000 / 1000 / 1000) + \" [GFLOPS]\")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zR_nwg0NeuI-", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def get_flops(model_h5_path):\n", | |
" session = tf.compat.v1.Session()\n", | |
" graph = tf.compat.v1.get_default_graph()\n", | |
" \n", | |
"\n", | |
" with graph.as_default():\n", | |
" with session.as_default():\n", | |
" model = tf.keras.models.load_model(model_h5_path)\n", | |
"\n", | |
" run_meta = tf.compat.v1.RunMetadata()\n", | |
" opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()\n", | |
" \n", | |
" # We use the Keras session graph in the call to the profiler.\n", | |
" flops = tf.compat.v1.profiler.profile(graph=graph,\n", | |
" run_meta=run_meta, cmd='op', options=opts)\n", | |
" \n", | |
" return flops.total_float_ops\n", | |
"print((str)(get_flops(model_name + \".h5\") / 1000 / 1000 / 1000) + \" [GFLOPS]\")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |