@ohtaman
Last active May 2, 2021 13:23
EdgeTPU with Keras
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "EdgeTPU with Keras",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ohtaman/c1cf119c463fd94b0da50feea320ba1e/edgetpu-with-keras.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "auEGoMDKW8Z7",
"colab_type": "text"
},
"source": [
"Build a model with Keras and convert it into a tflite file for the Edge TPU.\n",
"\n",
"### Install the Edge TPU Compiler"
]
},
{
"cell_type": "code",
"metadata": {
"id": "qLXmBB9bQoT7",
"colab_type": "code",
"outputId": "1c563341-3593-438b-b88b-beaa24ce111a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 379
}
},
"source": [
"%%bash\n",
"\n",
"echo \"deb https://packages.cloud.google.com/apt coral-edgetpu-stable main\" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list\n",
"sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 6A030B21BA07F4FB\n",
"\n",
"sudo apt update > /dev/null\n",
"sudo apt install edgetpu > /dev/null"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"deb https://packages.cloud.google.com/apt coral-edgetpu-stable main\n",
"Executing: /tmp/apt-key-gpghome.z6CQ29em1B/gpg.1.sh --keyserver keyserver.ubuntu.com --recv-keys 6A030B21BA07F4FB\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Warning: apt-key output should not be parsed (stdout is not a terminal)\n",
"gpg: key 6A030B21BA07F4FB: public key \"Google Cloud Packages Automatic Signing Key <[email protected]>\" imported\n",
"gpg: Total number processed: 1\n",
"gpg: imported: 1\n",
"\n",
"WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
"\n",
"\n",
"WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
"\n",
"debconf: unable to initialize frontend: Dialog\n",
"debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 76, <> line 5.)\n",
"debconf: falling back to frontend: Readline\n",
"debconf: unable to initialize frontend: Readline\n",
"debconf: (This frontend requires a controlling tty.)\n",
"debconf: falling back to frontend: Teletype\n",
"dpkg-preconfigure: unable to re-open stdin: \n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NDYJKuhfRoyo",
"colab_type": "text"
},
"source": [
"## Edge TPU with Keras\n",
"\n",
"We build a very simple model in this notebook.\n",
"\n",
"- data: Fashion-MNIST\n",
"- input shape: 28 x 28\n",
"- output shape: 10\n",
"- hidden layers: a single dense layer (128 units, ReLU) followed by batch normalization"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3aluRnFKRN9y",
"colab_type": "code",
"outputId": "df9ec9d9-6580-4f99-8476-caf6452c4ef5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"print(tf.__version__)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"1.13.1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "p40JjhzMSU92",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 161
},
"outputId": "a48582e4-a612-4b49-be62-0ffe3d893a14"
},
"source": [
"fashion_mnist = keras.datasets.fashion_mnist\n",
"(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
"\n",
"train_images = train_images / 255.0\n",
"test_images = test_images / 255.0"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n",
"32768/29515 [=================================] - 0s 0us/step\n",
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n",
"26427392/26421880 [==============================] - 0s 0us/step\n",
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n",
"8192/5148 [===============================================] - 0s 0us/step\n",
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n",
"4423680/4422102 [==============================] - 0s 0us/step\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vhC7xCG6S-yD",
"colab_type": "text"
},
"source": [
"### Build the model\n",
"\n",
"- Define a `build_keras_model` function, since the model has to be built twice (once for training and once for evaluation)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "dNuqMwp5f_GX",
"colab_type": "code",
"colab": {}
},
"source": [
"def build_keras_model():\n",
" return keras.Sequential([\n",
" keras.layers.Flatten(input_shape=(28, 28)),\n",
" keras.layers.Dense(128, activation='relu'),\n",
" keras.layers.BatchNormalization(),\n",
" keras.layers.Dense(10, activation='softmax')\n",
" ])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Vk9oFSZXLl-",
"colab_type": "text"
},
"source": [
"## Train the model and save its checkpoints\n",
"\n",
"- Use a new Session and Graph so that the training and evaluation phases use exactly the same variable names.\n",
"- Call `tf.contrib.quantize.create_training_graph` after building the model, since we want quantization-aware training. With `quant_delay=100`, weights and activations are only fake-quantized after the first 100 training steps."
]
},
{
"cell_type": "code",
"metadata": {
"id": "jUuyAcC8ypv6",
"colab_type": "code",
"outputId": "16fbd8cf-6503-4e1e-c83a-ad0ef380c479",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 487
}
},
"source": [
"# train\n",
"train_graph = tf.Graph()\n",
"train_sess = tf.Session(graph=train_graph)\n",
"\n",
"keras.backend.set_session(train_sess)\n",
"with train_graph.as_default():\n",
" train_model = build_keras_model()\n",
" \n",
" tf.contrib.quantize.create_training_graph(input_graph=train_graph, quant_delay=100)\n",
" train_sess.run(tf.global_variables_initializer()) \n",
"\n",
" train_model.compile(\n",
" optimizer='adam',\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy']\n",
" )\n",
" train_model.fit(train_images, train_labels, epochs=5)\n",
" \n",
" # save graph and checkpoints\n",
" saver = tf.train.Saver()\n",
" saver.save(train_sess, 'checkpoints')"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n",
"\n",
"WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
"For more information, please see:\n",
" * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
" * https://github.com/tensorflow/addons\n",
"If you depend on functionality not listed there, please file an issue.\n",
"\n",
"INFO:tensorflow:Inserting fake quant op activation_Mul_quant after batch_normalization_v1/batchnorm/mul_1\n",
"INFO:tensorflow:Inserting fake quant op activation_Add_quant after batch_normalization_v1/batchnorm/add_1\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.cast instead.\n",
"Epoch 1/5\n",
"60000/60000 [==============================] - 13s 222us/sample - loss: 0.4890 - acc: 0.8276\n",
"Epoch 2/5\n",
"60000/60000 [==============================] - 12s 207us/sample - loss: 0.4008 - acc: 0.8576\n",
"Epoch 3/5\n",
"60000/60000 [==============================] - 12s 202us/sample - loss: 0.3720 - acc: 0.8651\n",
"Epoch 4/5\n",
"60000/60000 [==============================] - 12s 202us/sample - loss: 0.3515 - acc: 0.8716\n",
"Epoch 5/5\n",
"60000/60000 [==============================] - 12s 203us/sample - loss: 0.3359 - acc: 0.8781\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "jWp9_I06ZjDo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 71
},
"outputId": "44e4fd5e-9ee7-49d8-9ed9-e468a4cb9cd9"
},
"source": [
"with train_graph.as_default():\n",
" print('sample result of original model')\n",
" print(train_model.predict(test_images[:1]))"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"sample result of original model\n",
"[[3.3356025e-04 1.5380654e-05 1.2752986e-04 7.1626651e-05 5.9096528e-05\n",
" 4.0830608e-02 9.5574775e-05 7.2698116e-02 2.7520847e-04 8.8549328e-01]]\n"
],
"name": "stdout"
}
]
},
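{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `create_training_graph` call in the training cell inserts fake-quant ops into the graph (see the `Inserting fake quant op` log lines). The cell below is a minimal, simplified sketch of what such an op computes: each value is quantized onto an 8-bit grid spanning a tracked `[min, max]` range and immediately dequantized, so training already sees the rounding error that 8-bit inference will introduce. `fake_quant` is a hypothetical pure-Python helper for illustration only. TensorFlow's actual `FakeQuantWithMinMaxVars` op also nudges the range so that 0.0 is exactly representable, which this sketch omits."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical helper: simplified fake quantization (quantize, then dequantize).\n",
"# Assumption: unsigned 8-bit grid, no range nudging (TF's op also nudges min/max).\n",
"def fake_quant(x, min_val, max_val, num_bits=8):\n",
"    levels = 2 ** num_bits - 1                # 255 steps for 8 bits\n",
"    scale = (max_val - min_val) / levels      # width of one quantization step\n",
"    zero_point = round(-min_val / scale)      # integer level that represents 0.0\n",
"    q = round(x / scale) + zero_point         # snap to the nearest level\n",
"    q = max(0, min(levels, q))                # clip to the representable range\n",
"    return (q - zero_point) * scale           # dequantize back to float\n",
"\n",
"\n",
"# Values inside [0, 6] move by at most half a step; values outside are clipped.\n",
"for x in [-1.0, 0.1, 2.0, 10.0]:\n",
"    print(x, '->', fake_quant(x, 0.0, 6.0))"
],
"execution_count": 0,
"outputs": []
},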
{
"cell_type": "markdown",
"metadata": {
"id": "dYOBvXw6X03l",
"colab_type": "text"
},
"source": [
"### Freeze model and save it\n",
"\n",
"- Create a new Session and Graph.\n",
"- After building the model, call `tf.contrib.quantize.create_eval_graph` and take the `graph_def` before calling `saver.restore`.\n",
"- Call `saver.restore` to load the trained weights.\n",
"  - Creating the `tf.train.Saver` adds save/restore ops that we don't want in the exported graph, so we take the `graph_def` before `saver.restore` is called.\n",
"- Use `tf.graph_util.convert_variables_to_constants` to freeze the `graph_def`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "3tI1s0JngKN0",
"colab_type": "code",
"outputId": "feae64dd-5332-45bd-b7a7-19774c1568ee",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
}
},
"source": [
"# eval\n",
"eval_graph = tf.Graph()\n",
"eval_sess = tf.Session(graph=eval_graph)\n",
"\n",
"keras.backend.set_session(eval_sess)\n",
"\n",
"with eval_graph.as_default():\n",
" keras.backend.set_learning_phase(0)\n",
" eval_model = build_keras_model()\n",
" tf.contrib.quantize.create_eval_graph(input_graph=eval_graph)\n",
" eval_graph_def = eval_graph.as_graph_def()\n",
" saver = tf.train.Saver()\n",
" saver.restore(eval_sess, 'checkpoints')\n",
"\n",
" frozen_graph_def = tf.graph_util.convert_variables_to_constants(\n",
" eval_sess,\n",
" eval_graph_def,\n",
" [eval_model.output.op.name]\n",
" )\n",
"\n",
" with open('frozen_model.pb', 'wb') as f:\n",
" f.write(frozen_graph_def.SerializeToString())"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Inserting fake quant op activation_Mul_quant after batch_normalization_v1/batchnorm/mul_1\n",
"INFO:tensorflow:Inserting fake quant op activation_Add_quant after batch_normalization_v1/batchnorm/add_1\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use standard file APIs to check for files with this prefix.\n",
"INFO:tensorflow:Restoring parameters from checkpoints\n",
"WARNING:tensorflow:From <ipython-input-7-995fccbe9e12>:17: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.compat.v1.graph_util.convert_variables_to_constants\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/graph_util_impl.py:245: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.compat.v1.graph_util.extract_sub_graph\n",
"INFO:tensorflow:Froze 20 variables.\n",
"INFO:tensorflow:Converted 20 variables to const ops.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tETygOV_Y_cX",
"colab_type": "text"
},
"source": [
"### Generate tflite file\n",
"\n",
"- Use the `QUANTIZED_UINT8` inference type.\n",
"- Quantization-aware training records min/max ranges in the fake-quant ops, so we don't need `--default_ranges_min`/`--default_ranges_max`.\n",
"- We don't need to run `freeze_graph.py`, since the graph is already frozen."
]
},
{
"cell_type": "code",
"metadata": {
"id": "APsbHmt7izT8",
"colab_type": "code",
"outputId": "5c6beaae-afa0-4f66-a736-f57b048159bf",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 325
}
},
"source": [
"%%bash\n",
"\n",
"tflite_convert \\\n",
" --output_file=model.tflite \\\n",
" --graph_def_file=frozen_model.pb \\\n",
" --inference_type=QUANTIZED_UINT8 \\\n",
" --input_arrays=flatten_input \\\n",
" --output_arrays=dense_1/Softmax \\\n",
" --mean_values=0 \\\n",
" --std_dev_values=255"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"2019-06-07 19:27:39.747839: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2300000000 Hz\n",
"2019-06-07 19:27:39.748212: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x55698b291760 executing computations on platform Host. Devices:\n",
"2019-06-07 19:27:39.748240: I tensorflow/compiler/xla/service/service.cc:158] StreamExecutor device (0): <undefined>, <undefined>\n",
"2019-06-07 19:27:39.900149: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2019-06-07 19:27:39.900697: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x55698b291340 executing computations on platform CUDA. Devices:\n",
"2019-06-07 19:27:39.900729: I tensorflow/compiler/xla/service/service.cc:158] StreamExecutor device (0): Tesla T4, Compute Capability 7.5\n",
"2019-06-07 19:27:39.901220: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties: \n",
"name: Tesla T4 major: 7 minor: 5 memoryClockRate(GHz): 1.59\n",
"pciBusID: 0000:00:04.0\n",
"totalMemory: 14.73GiB freeMemory: 14.33GiB\n",
"2019-06-07 19:27:39.901246: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1512] Adding visible gpu devices: 0\n",
"2019-06-07 19:27:41.341193: I tensorflow/core/common_runtime/gpu/gpu_device.cc:984] Device interconnect StreamExecutor with strength 1 edge matrix:\n",
"2019-06-07 19:27:41.341263: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990] 0 \n",
"2019-06-07 19:27:41.341276: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 0: N \n",
"2019-06-07 19:27:41.341574: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.\n",
"2019-06-07 19:27:41.341674: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 13853 MB memory) -> physical GPU (device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5)\n"
],
"name": "stderr"
}
]
},
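{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tflite_convert` interprets `--mean_values` and `--std_dev_values` as `real_value = (quantized_value - mean_value) / std_dev_value`. The cell below is a small sketch, assuming that relation, that turns the two flags into the `(scale, zero_point)` pair the interpreter reports for the input tensor. With `mean_values=0` and `std_dev_values=255` the scale is 1/255, so raw pixels 0..255 map back to the 0.0..1.0 range produced by `/ 255.0` during training. `input_quant_params` is a hypothetical helper name."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical helper: convert tflite_convert's mean/std flags into (scale, zero_point).\n",
"def input_quant_params(mean_value, std_dev_value):\n",
"    # real = (q - mean) / std  ==  (1 / std) * (q - mean)\n",
"    # so scale = 1 / std and zero_point = mean.\n",
"    return 1.0 / std_dev_value, mean_value\n",
"\n",
"\n",
"scale, zero_point = input_quant_params(mean_value=0, std_dev_value=255)\n",
"print(scale, zero_point)\n",
"\n",
"# Raw uint8 pixel values and the float inputs they stand for.\n",
"for q in (0, 128, 255):\n",
"    print(q, '->', (q - zero_point) * scale)"
],
"execution_count": 0,
"outputs": []
},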
{
"cell_type": "markdown",
"metadata": {
"id": "WIKWO5MuZk8f",
"colab_type": "text"
},
"source": [
"### Check the generated tflite file\n",
"\n",
"- Use `tf.lite.Interpreter` to check that the generated file is valid."
]
},
{
"cell_type": "code",
"metadata": {
"id": "zI0zfQTL-p5U",
"colab_type": "code",
"outputId": "c277b2e7-ebef-4cbb-cbcb-02448b0fe78a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 73
}
},
"source": [
"# load TFLite file\n",
"interpreter = tf.lite.Interpreter(model_path='model.tflite')\n",
"# Allocate memory for the input/output tensors.\n",
"interpreter.allocate_tensors()\n",
"\n",
"# Get tensor details (name, index, shape, dtype, quantization).\n",
"input_details = interpreter.get_input_details()\n",
"output_details = interpreter.get_output_details()\n",
"\n",
"print(input_details)\n",
"print(output_details)"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"[{'name': 'flatten_input', 'index': 11, 'shape': array([ 1, 28, 28], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.003921568859368563, 0)}]\n",
"[{'name': 'dense_1/Softmax', 'index': 9, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.00390625, 0)}]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "16MsnRzrZ0Yk",
"colab_type": "text"
},
"source": [
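"- The `quantization` attribute in `input_details`/`output_details` is a `(scale, zero_point)` pair. TFLite uses the affine mapping `real_value = scale * (quantized_value - zero_point)`, so:\n",
"  - a float input `f` is quantized as `round(f / scale) + zero_point`, clipped to the uint8 range and cast to uint8, and an output `q` is dequantized as `scale * (q - zero_point)`."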
]
},
{
"cell_type": "code",
"metadata": {
"id": "R_O6nfR4XCmo",
"colab_type": "code",
"colab": {}
},
"source": [
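"def quantize(detail, data):\n",
"    # Float -> integer: q = round(data / scale) + zero_point, clipped to the dtype range.\n",
"    # Rounding matters: plain astype() truncates, e.g. 254.99998 -> 254.\n",
"    shape = detail['shape']\n",
"    dtype = detail['dtype']\n",
"    scale, zero_point = detail['quantization']\n",
"    info = np.iinfo(dtype)\n",
"    q = np.round(data / scale) + zero_point\n",
"    return np.clip(q, info.min, info.max).astype(dtype).reshape(shape)\n",
"\n",
"\n",
"def dequantize(detail, data):\n",
"    # Integer -> float: real = scale * (q - zero_point); cast first to avoid uint8 wrap-around.\n",
"    scale, zero_point = detail['quantization']\n",
"    return (data.astype(np.float32) - zero_point) * scale"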
],
"execution_count": 0,
"outputs": []
},
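{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick arithmetic check of the affine mapping, using the `(scale, zero_point)` values printed for the input `(0.003921568859368563, 0)` and output `(0.00390625, 0)` tensors. It also shows why rounding matters: the stored input scale is the float32 value of 1/255, which is slightly larger than 1/255, so a white pixel (1.0) lands just below 255 and plain truncation with `astype` turns it into 254. `to_uint8` and `from_uint8` are hypothetical plain-Python helpers for this sketch, not part of the TFLite API."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"IN_SCALE, IN_ZERO = 0.003921568859368563, 0   # from input_details[0]['quantization']\n",
"OUT_SCALE, OUT_ZERO = 0.00390625, 0           # from output_details[0]['quantization']\n",
"\n",
"\n",
"def to_uint8(x, scale, zero_point):\n",
"    # Hypothetical helper: q = round(x / scale) + zero_point, clipped to [0, 255].\n",
"    q = round(x / scale) + zero_point\n",
"    return max(0, min(255, q))\n",
"\n",
"\n",
"def from_uint8(q, scale, zero_point):\n",
"    # Hypothetical helper: real = scale * (q - zero_point).\n",
"    return scale * (q - zero_point)\n",
"\n",
"\n",
"white = 255 / 255.0\n",
"print(white / IN_SCALE)                      # slightly below 255\n",
"print(int(white / IN_SCALE))                 # truncation -> 254\n",
"print(to_uint8(white, IN_SCALE, IN_ZERO))    # rounding   -> 255\n",
"print(from_uint8(226, OUT_SCALE, OUT_ZERO))  # -> 0.8828125"
],
"execution_count": 0,
"outputs": []
},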
{
"cell_type": "code",
"metadata": {
"id": "jJuqJna_vgna",
"colab_type": "code",
"outputId": "0e4dc5c9-117d-430d-fec7-8cbd1d8e1507",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 71
}
},
"source": [
"quantized_input = quantize(input_details[0], test_images[:1])\n",
"interpreter.set_tensor(input_details[0]['index'], quantized_input)\n",
"\n",
"interpreter.invoke()\n",
"\n",
"# The result is stored in the output tensor at output_details[0]['index']\n",
"quantized_output = interpreter.get_tensor(output_details[0]['index'])\n",
"\n",
"print('sample result of quantized model')\n",
"print(dequantize(output_details[0], quantized_output))"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"sample result of quantized model\n",
"[[0. 0. 0. 0. 0. 0.04296875\n",
" 0. 0.07421875 0. 0.8828125 ]]\n"
],
"name": "stdout"
}
]
},
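{
"cell_type": "markdown",
"metadata": {},
"source": [
"To compare the two runs beyond eyeballing, the sketch below copies the two printed sample vectors for `test_images[:1]`. It checks that both models pick the same class and that every dequantized score stays within one output quantization step (1/256 = 0.00390625) of the float score. This only covers a single image. A real accuracy comparison would run the whole test set through both models."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sample outputs copied from the cells above (first test image).\n",
"float_out = [3.3356025e-04, 1.5380654e-05, 1.2752986e-04, 7.1626651e-05, 5.9096528e-05,\n",
"             4.0830608e-02, 9.5574775e-05, 7.2698116e-02, 2.7520847e-04, 8.8549328e-01]\n",
"quant_out = [0.0, 0.0, 0.0, 0.0, 0.0, 0.04296875, 0.0, 0.07421875, 0.0, 0.8828125]\n",
"OUT_SCALE = 0.00390625  # output quantization step from output_details\n",
"\n",
"\n",
"def argmax(values):\n",
"    # Index of the largest score.\n",
"    return max(range(len(values)), key=values.__getitem__)\n",
"\n",
"\n",
"max_err = max(abs(f - q) for f, q in zip(float_out, quant_out))\n",
"print('top-1 (float, quantized):', argmax(float_out), argmax(quant_out))\n",
"print('max abs error: %.5f (%.2f steps)' % (max_err, max_err / OUT_SCALE))"
],
"execution_count": 0,
"outputs": []
},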
{
"cell_type": "markdown",
"metadata": {
"id": "9rFwYYH5aZjy",
"colab_type": "text"
},
"source": [
"### Compile the tflite file with the Edge TPU Compiler"
]
},
{
"cell_type": "code",
"metadata": {
"id": "GmSX6C6RxAZh",
"colab_type": "code",
"outputId": "96982a5b-ec8e-46fa-f924-0251ab9c9605",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 305
}
},
"source": [
"%%bash\n",
"\n",
"edgetpu_compiler 'model.tflite'"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"Edge TPU Compiler version 1.0.249710469\n",
"\n",
"Model compiled successfully in 21 ms.\n",
"\n",
"Input model: model.tflite\n",
"Input size: 102.27KiB\n",
"Output model: model_edgetpu.tflite\n",
"Output size: 152.55KiB\n",
"On-chip memory available for caching model parameters: 8.01MiB\n",
"On-chip memory used for caching model parameters: 106.75KiB\n",
"Off-chip memory used for streaming uncached model parameters: 4.00KiB\n",
"Number of Edge TPU subgraphs: 1\n",
"Total number of operations: 5\n",
"Operation log: model_edgetpu.log\n",
"See the operation log file for individual operation details.\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO: Initialized TensorFlow Lite runtime.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tQ9qGJwZf_AV",
"colab_type": "text"
},
"source": [
"We can download the generated file."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ygUHjCLOz41Q",
"colab_type": "code",
"colab": {}
},
"source": [
"from google.colab import files\n",
"\n",
"files.download('model_edgetpu.tflite')"
],
"execution_count": 0,
"outputs": []
}
]
}