Model Maker Object Detection for RICO.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/dusskapark/8adf27e4b8c6e7c392811d24547bf27e/model-maker-object-detection-for-rico.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gf2if_fGDaWc"
},
"source": [
"##### References \n",
"\n",
"> This Colab notebook was made by modifying the references below: \n",
"> - [Train a custom object detection model using your data](https://youtu.be/-ZyFYniGUsw)\n",
"> - [Model Maker Object Detection for Android Figurine](https://colab.research.google.com/github/khanhlvg/tflite_raspberry_pi/blob/main/object_detection/Train_custom_model_tutorial.ipynb)\n",
"> - [Object Detection with TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/tutorials/model_maker_object_detection)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "jrmj83afDJrv"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PpJEzDG6DK2Q"
},
"source": [
"# Train a custom object detection model with TensorFlow Lite Model Maker\n",
"\n",
"In this Colab notebook, you'll learn how to use [TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/guide/model_maker) to train a custom object detection model that detects UI objects, and how to export the trained model for use with TensorFlow.js (TFJS).\n",
"\n",
"The Model Maker library uses *transfer learning* to simplify the process of training a TensorFlow Lite model using a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data required and will shorten the training time.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BRYjtwRZGBOI"
},
"source": [
"## Preparation\n",
"\n",
"### Install the required packages\n",
"Start by installing the required packages, including the Model Maker package from the [GitHub repo](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker) and the pycocotools library you'll use for evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "35BJmtVpAP_n",
"outputId": "7864c1ac-1760-4f3a-c8ea-a3f98e58b332"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n",
"Note: you may need to restart the kernel to use updated packages.\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -q tflite-model-maker\n",
"%pip install -q tflite-support\n",
"%pip install -q pycocotools"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "prQ86DdtD317"
},
"source": [
"Import the required packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l4QQTXHHATDS"
},
"outputs": [],
"source": [
"from absl import logging\n",
"import tensorflow as tf\n",
"from tflite_support import metadata\n",
"import numpy as np\n",
"import os\n",
"\n",
"from tflite_model_maker.config import ExportFormat, QuantizationConfig\n",
"from tflite_model_maker import model_spec\n",
"from tflite_model_maker import object_detector\n",
"\n",
"\n",
"assert tf.__version__.startswith('2')\n",
"\n",
"tf.get_logger().setLevel('ERROR')\n",
"logging.set_verbosity(logging.ERROR)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3g6aQvXsD78P"
},
"source": [
"### Prepare the dataset\n",
"\n",
"In this notebook, we're going to rearrange and use the RICO dataset. If you would like to manually build training and validation datasets, please click [this link](https://interactionmining.org/rico) and download RICO manually.\n",
"\n",
"#### Downloading data from RICO dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8AGg7D4JAV62"
},
"outputs": [],
"source": [
"!curl -L \"https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz\" > jpg.tar.gz; tar -zxvf jpg.tar.gz; rm jpg.tar.gz\n",
"!curl -L \"https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/semantic_annotations.zip\" > json.zip; unzip json.zip; rm json.zip"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-6tDEWQSVtqb"
},
"source": [
"Download two compressed files from RICO. Each file contains pairs of images and JSON files. We save space by deleting all files except for the original image files and annotation files we need."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yurpxo2uVwBl"
},
"outputs": [],
"source": [
"# remove unused pairs\n",
"!find combined/. -name '*.json' -type f -delete\n",
"!find combined/. -name \"*.jpg\" | wc -l\n",
"!find semantic_annotations/. -name '*.png' -type f -delete\n",
"!find semantic_annotations/. -name \"*.json\" | wc -l"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JWD2NiCKVzDP"
},
"source": [
"#### Scaled images \n",
"\n",
"Most of RICO's screenshots are high-resolution images at 1440 × 2560 pixels. If you use these directly, they will use a lot of resources with regards to the GPU and memory within Google Colab's training environment. \n",
"\n",
"So we'll reduce all images to 640px height JPG. Later, we're going to change the value inside annotation files too.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RDpic-LQVz6B"
},
"outputs": [],
"source": [
"from PIL import Image\n",
"import os\n",
"\n",
"raw_path = './combined/' # source image path\n",
"data_path = './jpg/' # resized image path\n",
"\n",
"# Start resize --------------------\n",
"# Create data_path if it does not exist\n",
"if not os.path.exists(data_path):\n",
"    os.mkdir(data_path)\n",
"\n",
"# List all images in the source image path\n",
"data_list = os.listdir(raw_path)\n",
"print(len(data_list))\n",
"\n",
"# Resize every image to 360 x 640 and save it\n",
"for name in data_list:\n",
"    im = Image.open(raw_path + name)\n",
"    im = im.resize((360, 640))\n",
"    im.save(data_path + name)\n",
"    print('end ::: ' + data_path + name)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qnSNCdVpV7_B"
},
"source": [
"#### Convert to XML\n",
"\n",
"We'll extract only the necessary information such as such as the bounding box, filename, and component name from the JSON files and convert them to xml format.\n",
"\n",
"During this conversion, I discovered some data had unexpected errors:\n",
"\n",
"- Bounding box (bndbox) is sometimes negative\n",
"- Component’s xmax or ymax value is sometimes greater than the overall width and height values of the screenshot\n",
"- xmin is sometimes greater than xmax or ymin is greater than ymax\n",
"- And so on…\n",
"\n",
"So, I skipped all of those kinda errored items. Please check the isvalidbdnbox function below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A6P3not9XWU3"
},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"from xml.etree.ElementTree import Element, SubElement\n",
"\n",
"def beautify(elem, indent=0):\n",
" \"\"\"\n",
" xml 트리를 문자열로 변환합니다.\n",
" Converts the XML tree to a string.\n",
"\n",
" :param elem: xml element\n",
" :param indent: Indent Level to display in front\n",
" \"\"\"\n",
" result0 = f\"{' ' * indent}<{elem.tag}>\"\n",
" # 값이 있는 태그면 값을 바로 출력\n",
" # If there is a value in the tag, immediately print the value\n",
" if elem.text is not None:\n",
" result0 += elem.text + f\"</{elem.tag}>\\n\"\n",
" # 값이 없고 자식 노드가 있으면 재귀 호출로 출력합니다.\n",
" # If the tag has no value and has child nodes, call the recursive function.\n",
" else:\n",
" result0 += \"\\n\"\n",
" for _child in elem:\n",
" result0 += beautify(_child, indent + 1)\n",
" result0 += f\"{' ' * indent}</{elem.tag}>\\n\"\n",
" return result0\n",
"\n",
"def isvalidbdnbox(xmin,xmax,ymin,ymax):\n",
" \"\"\"\n",
" Validate the bnd box.\n",
" 다중 if를 추가해서 bnd 박스를 검증해야 합니다.\n",
" \"\"\"\n",
" # bounding box(bndbox)가 음수이면 안됩니다. \n",
" # bounding box(bndbox) cannot be negative.\n",
" if(xmin<0 or xmax<0 or ymin<0 or ymax<0):\n",
" print(\"bndbox cannot be negative.\")\n",
" return False\n",
" # bndbox의 xmax 값은 1440, ymax 값은 2560을 넘으면 안됩니다. \n",
" # The xmax value of bndbox cannot exceed 1440 and the ymax value cannot exceed 2560.\n",
" if(xmax>1440 or ymax>2560):\n",
" print(\"bndbox cannot exceed the screenshot\")\n",
" return False\n",
"\n",
" # xmin 값은 xmax 보다 클 수 없습니다. 또는 ymin 값은 ymax 보다 클 수 없습니다. \n",
" # xmin cannot be greater than xmax or ymin cannot be greater than ymax.\n",
" \n",
" if(xmin>xmax or ymin>ymax):\n",
" print(\"xmin>xmax or ymin>ymax\")\n",
" return False\n",
"\n",
" if(xmin==xmax or ymin==ymax):\n",
" print(\"xmin=xmax or ymin=ymax\")\n",
" return False\n",
" \n",
" \n",
" return True\n",
"\n",
"def recursive(child, result_out):\n",
" \"\"\"\n",
" 원하는 역할을 하기 위해서 재귀호출을 할 수 있는 함수를 생성합니다.\n",
" Create a function that makes a recursive call.\n",
" \"\"\"\n",
" obj = Element(\"object\")\n",
"\n",
" # Set bounds\n",
" bounds = child['bounds']\n",
"\n",
" # Set name\n",
" SubElement(obj, \"name\").text = child['componentLabel']\n",
" # Set difficult, truncated, pose\n",
" SubElement(obj, \"difficult\").text = '0'\n",
" SubElement(obj, \"truncated\").text = 'Unspecified'\n",
" SubElement(obj, \"pose\").text = 'Undefined'\n",
"\n",
" # Set bndbox \n",
" bndbox = SubElement(obj, \"bndbox\")\n",
" xmin=bounds[0]\n",
" ymin=bounds[1]\n",
" xmax=bounds[2]\n",
" ymax=bounds[3]\n",
" \n",
" if(isvalidbdnbox(xmin,xmax,ymin,ymax)):\n",
" SubElement(bndbox, \"xmin\").text = str(round(xmin/4))\n",
" SubElement(bndbox, \"ymin\").text = str(round(ymin/4))\n",
" SubElement(bndbox, \"xmax\").text = str(round(xmax/4))\n",
" SubElement(bndbox, \"ymax\").text = str(round(ymax/4))\n",
" result_out.append(beautify(obj))\n",
"\n",
" \n",
" # 생성한 object 태그를 문자열로 변환해서 추가합니다.\n",
" # Convert the created object tag to a string and add it.\n",
" if 'children' not in child:\n",
" return\n",
"\n",
" # 자식 노드가 있는 경우 자식 노드에 대해 재귀 호출을 수행합니다.\n",
" # If there is a child node, make a recursive call to the child node.\n",
" for ch in child.get('children', []):\n",
" recursive(ch, result_out)\n",
"\n",
"\n",
"def json2xml(infile, outfile):\n",
" \"\"\"\n",
" json2xml function\n",
" param infile :\n",
" ourfile :\n",
" \"\"\"\n",
" result_out = []\n",
" imgName = infile.replace(\"./semantic_annotations/\", \"\")\n",
" imgName = imgName.replace(\"json\", \"jpg\")\n",
"\n",
" # Read the file\n",
" with open(infile, \"r\", encoding=\"UTF-8\") as f:\n",
" data = json.load(f)\n",
"\n",
" children = data['children']\n",
"\n",
" # 자식 노드에 대해 재귀 호출을 수행합니다.\n",
" # Make a recursive call on child nodes.\n",
" for child in children:\n",
" recursive(child, result_out)\n",
"\n",
" # 그리고 해당 결과를 파일로 저장합니다.\n",
" # And save the result to the XML file.\n",
" with open(outfile, \"w\", encoding=\"UTF-8\") as f:\n",
" f.write(\"\".join(\"<annotation><folder />\"))\n",
" f.write(\"\".join(\"<filename>\" + imgName + \"</filename>\\n\"\n",
" \"<path>\" + imgName + \"</path>\\n\"+\n",
" \"<source><database>RICO</database></source><size><width>360</width><height>640</height><depth>3</depth></size><segmented>0</segmented> \"))\n",
" f.write(\"\".join(result_out))\n",
" f.write(\"\".join(\"</annotation>\"))\n",
"\n",
"\n",
"def search(mypath):\n",
" onlyfiles = [f for f in os.listdir(mypath)\n",
" if os.path.isfile(os.path.join(mypath, f))]\n",
" onlyfiles.sort() \n",
" return onlyfiles\n",
"\n",
"\n",
"def main():\n",
" \"\"\"\n",
" 메인 함수\n",
" The main function \n",
" \"\"\"\n",
" data_path = './xml/'\n",
" if not os.path.exists(data_path):\n",
" os.mkdir(data_path)\n",
"\n",
"\n",
" files = os.listdir('./semantic_annotations')\n",
" # print(files)\n",
" for infile in files:\n",
" infile = f'./semantic_annotations/{infile}'\n",
" outfile = infile.replace(\"semantic_annotations\", \"xml\").replace(\"json\", \"xml\")\n",
" print(infile, outfile)\n",
" json2xml(infile, outfile)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" main()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5dXpiFSoV_8f"
},
"source": [
"#### Generate Label Map\n",
"\n",
"Next, we generate `label_map.pbtxt` from the component names found in the XML files and, at the same time, build the `label_map` list used later to load the dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Sg8P9fdBWCzj",
"outputId": "95484b7f-2566-4c91-eaf1-6e046bb65b4d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Image', 'Input', 'On/Off Switch', 'Date Picker', 'Map View', 'List Item', 'Advertisement', 'Card', 'Checkbox', 'Drawer', 'Web View', 'Radio Button', 'Video', 'Button Bar', 'Text', 'Bottom Navigation', 'Toolbar', 'Number Stepper', 'Text Button', 'Pager Indicator', 'Icon', 'Slider', 'Modal', 'Background Image', 'Multi-Tab']\n"
]
}
],
"source": [
"import os\n",
"import xml.etree.ElementTree as ET\n",
"\n",
"obj = []\n",
"\n",
"# Collect the component name of every object in every XML file\n",
"for filename in os.listdir(\"./xml/\"):\n",
"    tree = ET.parse(os.path.join(\"./xml/\", filename))\n",
"    root = tree.getroot()\n",
"    objects = root.findall(\"object\")\n",
"    name = [x.findtext(\"name\") for x in objects]\n",
"\n",
"    for i in name:\n",
"        obj.append(i)\n",
"\n",
"# Keep each component name once\n",
"obj_unique = list(set(obj))\n",
"\n",
"pbtxt = \"\"\n",
"label_map = []\n",
"for i in range(len(obj_unique)):\n",
"    pbtxt += \"item {\\n name: \\\"\"+obj_unique[i]+\"\\\",\\n id: \"+str(i+1)+\"\\n}\\n\"+\"\\n\"\n",
"    label_map.append(obj_unique[i])\n",
"\n",
"with open(\"label_map.pbtxt\", \"w\", encoding=\"utf-8\") as f:\n",
"    f.write(pbtxt)\n",
"\n",
"print(label_map)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9hbPvymZofa"
},
"source": [
"#### Partition the dataset\n",
"\n",
"Next, we're going to split our dataset into training and testing subsets. A typical ratio is 9:1: 90% of the images are used for training and the remaining 10% are held out for testing, but you can choose whatever ratio suits your needs.\n",
"\n",
"[Lyudmil Vladimirov](https://github.com/sglvladi) has published a great code example on [splitting the dataset](https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/training.html#partition-the-dataset)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jaEobd0sZrBe"
},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"from shutil import copyfile\n",
"import argparse\n",
"import math\n",
"import random\n",
"\n",
"\n",
"def iterate_dir(source, dest, ratio, xml_source):\n",
"    source = source.replace('\\\\', '/')\n",
"    xml_source = xml_source.replace('\\\\', '/')\n",
"    dest = dest.replace('\\\\', '/')\n",
"    train_dir = os.path.join(dest, 'train')\n",
"    test_dir = os.path.join(dest, 'test')\n",
"\n",
"    if not os.path.exists(train_dir):\n",
"        os.makedirs(train_dir)\n",
"    if not os.path.exists(test_dir):\n",
"        os.makedirs(test_dir)\n",
"\n",
"    # The case-insensitive flag is passed as an argument because newer Python\n",
"    # versions reject an inline (?i) that is not at the start of the pattern.\n",
"    images = [f for f in os.listdir(source)\n",
"              if re.search(r'([a-zA-Z0-9\\s_\\\\.\\-\\(\\):])+(\\.jpg|\\.jpeg|\\.png)$', f, flags=re.IGNORECASE)]\n",
"\n",
"    num_images = len(images)\n",
"    num_test_images = math.ceil(ratio*num_images)\n",
"\n",
"    # Randomly move the test share of images (and their XML files) to test_dir\n",
"    for i in range(num_test_images):\n",
"        idx = random.randint(0, len(images)-1)\n",
"        filename = images[idx]\n",
"        copyfile(os.path.join(source, filename),os.path.join(test_dir, filename))\n",
"\n",
"        xml_filename = os.path.splitext(filename)[0]+'.xml'\n",
"        copyfile(os.path.join(xml_source, xml_filename),os.path.join(test_dir, xml_filename))\n",
"        images.remove(images[idx])\n",
"\n",
"    # Everything that is left goes to train_dir\n",
"    for filename in images:\n",
"        copyfile(os.path.join(source, filename),os.path.join(train_dir, filename))\n",
"\n",
"        xml_filename = os.path.splitext(filename)[0]+'.xml'\n",
"        copyfile(os.path.join(xml_source, xml_filename),os.path.join(train_dir, xml_filename))\n",
"\n",
"def main():\n",
"\n",
"    imageDir = './jpg/'\n",
"    outputDir = './'\n",
"    ratio = 0.1\n",
"    xmlDir = './xml/'\n",
"\n",
"    # Now we are ready to start the iteration\n",
"    iterate_dir(imageDir, outputDir, ratio, xmlDir)\n",
"\n",
"\n",
"if __name__ == '__main__':\n",
"    main()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XYvfmDYRdm3_",
"outputId": "4a436c7b-c34d-4a4c-e4e3-b7e2a9ea0799"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"119269\n",
"13255\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"UsageError: Line magic function `%rm` not found.\n"
]
}
],
"source": [
"# (optional) Leave only the necessary files.\n",
"\n",
"!find train/. | wc -l\n",
"!find test/. | wc -l\n",
"\n",
"!rm -rf semantic_annotations combined jpg\n",
"# !rm -rf test train"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yxh3KInCFeB-"
},
"source": [
"## Train the object detection model\n",
"\n",
"### Step 1: Load the dataset\n",
"\n",
"* Images in `train_data` are used to train the custom object detection model.\n",
"* Images in `val_data` are used to check whether the model generalizes well to new images that it hasn't seen before."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p0UAF9Qqpu5e"
},
"outputs": [],
"source": [
"# print(label_map)\n",
"train_data = object_detector.DataLoader.from_pascal_voc(\n",
" images_dir='./train',\n",
" annotations_dir='./train',\n",
" label_map=label_map\n",
")\n",
"\n",
"val_data = object_detector.DataLoader.from_pascal_voc(\n",
" images_dir='./test',\n",
" annotations_dir='./test',\n",
" label_map=label_map\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UNRhB8N7GHXj"
},
"source": [
"### Step 2: Select a model architecture\n",
"\n",
"EfficientDet-Lite[0-4] are a family of mobile/IoT-friendly object detection models derived from the [EfficientDet](https://arxiv.org/abs/1911.09070) architecture.\n",
"\n",
"Here is how the EfficientDet-Lite models compare with each other.\n",
"\n",
"| Model architecture | Size (MB)* | Latency (ms)** | Average Precision*** |\n",
"|--------------------|------------|----------------|----------------------|\n",
"| EfficientDet-Lite0 | 4.4        | 37             | 25.69%               |\n",
"| EfficientDet-Lite1 | 5.8        | 49             | 30.55%               |\n",
"| EfficientDet-Lite2 | 7.2        | 69             | 33.97%               |\n",
"| EfficientDet-Lite3 | 11.4       | 116            | 37.70%               |\n",
"| EfficientDet-Lite4 | 19.9       | 260            | 41.96%               |\n",
"\n",
"<i> * Size of the integer quantized models. <br/>\n",
"** Latency measured on Pixel 4 using 4 threads on CPU.<br/>\n",
"*** Average Precision is the mAP (mean Average Precision) on the COCO 2017 validation dataset.\n",
"</i>\n",
"\n",
"In this notebook, we use EfficientDet-Lite0 to train our model. You can choose other model architectures depending on whether speed or accuracy is more important to you. \n",
"\n",
"**Note:** Because of the large volume of the RICO dataset, you might need to pass the maximum instance count as an hparam when the model spec is created. Set `max_instances_per_image` slightly higher than the maximum number of objects you expect to see in a single image."
]
},
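{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to pick a value for `max_instances_per_image` is to count the `<object>` tags in each annotation file and take the largest count. The cell below is a minimal sketch, not part of the original workflow: the helper name `max_objects_per_image` is hypothetical, and it assumes the XML files created earlier are in `./train`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import xml.etree.ElementTree as ET\n",
"\n",
"def max_objects_per_image(xml_dir):\n",
"    # Hypothetical helper: largest number of <object> tags in any XML file.\n",
"    largest = 0\n",
"    for filename in os.listdir(xml_dir):\n",
"        if not filename.endswith('.xml'):\n",
"            continue\n",
"        root = ET.parse(os.path.join(xml_dir, filename)).getroot()\n",
"        largest = max(largest, len(root.findall('object')))\n",
"    return largest\n",
"\n",
"# Assumption: ./train holds the annotation files used for training.\n",
"print(max_objects_per_image('./train'))"
]
},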
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GZOojrDHAY1J"
},
"outputs": [],
"source": [
"spec = object_detector.EfficientDetSpec(model_name='efficientdet-lite0', uri='https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1', hparams={'max_instances_per_image': 8000})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aeDU4mIM4ft"
},
"source": [
"### Step 3: Train the TensorFlow model with the training data.\n",
"\n",
"* Set `epochs = 20`, which means it will go through the training dataset 20 times. You can watch the validation loss (`val_loss`) during training and stop when it no longer decreases, to avoid overfitting.\n",
"* Set `batch_size = 16`, so it takes 3,727 steps to go through all the screenshots in the training dataset.\n",
"* Set `train_whole_model=True` to fine-tune the whole model instead of just training the head layer to improve accuracy. The trade-off is that it may take longer to train the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 471
},
"id": "_MClfpsJAfda",
"outputId": "3ff302ef-20a6-400b-970a-f159e8214e34"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"3727/3727 [==============================] - 17525s 5s/step - det_loss: 1.0959 - cls_loss: 0.6612 - box_loss: 0.0087 - reg_l2_loss: 0.0672 - loss: 1.1631 - learning_rate: 0.0140 - gradient_norm: 1.4641 - val_det_loss: 1.0653 - val_cls_loss: 0.6432 - val_box_loss: 0.0084 - val_reg_l2_loss: 0.0673 - val_loss: 1.1326\n",
"Epoch 2/20\n",
"3727/3727 [==============================] - 17482s 5s/step - det_loss: 0.9268 - cls_loss: 0.5615 - box_loss: 0.0073 - reg_l2_loss: 0.0674 - loss: 0.9942 - learning_rate: 0.0197 - gradient_norm: 1.2754 - val_det_loss: 0.9822 - val_cls_loss: 0.5799 - val_box_loss: 0.0080 - val_reg_l2_loss: 0.0675 - val_loss: 1.0497\n",
"Epoch 3/20\n",
"3727/3727 [==============================] - 17462s 5s/step - det_loss: 0.8908 - cls_loss: 0.5412 - box_loss: 0.0070 - reg_l2_loss: 0.0675 - loss: 0.9582 - learning_rate: 0.0191 - gradient_norm: 1.2383 - val_det_loss: 0.9547 - val_cls_loss: 0.5637 - val_box_loss: 0.0078 - val_reg_l2_loss: 0.0674 - val_loss: 1.0221\n",
"Epoch 4/20\n",
"3727/3727 [==============================] - 17494s 5s/step - det_loss: 0.8632 - cls_loss: 0.5253 - box_loss: 0.0068 - reg_l2_loss: 0.0673 - loss: 0.9305 - learning_rate: 0.0184 - gradient_norm: 1.2212 - val_det_loss: 0.9452 - val_cls_loss: 0.5576 - val_box_loss: 0.0078 - val_reg_l2_loss: 0.0672 - val_loss: 1.0124\n",
"Epoch 5/20\n",
"3727/3727 [==============================] - 18025s 5s/step - det_loss: 0.8442 - cls_loss: 0.5165 - box_loss: 0.0066 - reg_l2_loss: 0.0670 - loss: 0.9112 - learning_rate: 0.0173 - gradient_norm: 1.2118 - val_det_loss: 0.8487 - val_cls_loss: 0.5037 - val_box_loss: 0.0069 - val_reg_l2_loss: 0.0668 - val_loss: 0.9155\n",
"Epoch 6/20\n",
"3727/3727 [==============================] - 17874s 5s/step - det_loss: 0.8329 - cls_loss: 0.5094 - box_loss: 0.0065 - reg_l2_loss: 0.0666 - loss: 0.8995 - learning_rate: 0.0161 - gradient_norm: 1.2254 - val_det_loss: 0.9447 - val_cls_loss: 0.5774 - val_box_loss: 0.0073 - val_reg_l2_loss: 0.0663 - val_loss: 1.0111\n",
"Epoch 7/20\n",
"3727/3727 [==============================] - 17848s 5s/step - det_loss: 0.8182 - cls_loss: 0.5023 - box_loss: 0.0063 - reg_l2_loss: 0.0661 - loss: 0.8844 - learning_rate: 0.0148 - gradient_norm: 1.2439 - val_det_loss: 0.9011 - val_cls_loss: 0.5337 - val_box_loss: 0.0073 - val_reg_l2_loss: 0.0659 - val_loss: 0.9670\n",
"Epoch 8/20\n",
"3727/3727 [==============================] - 17838s 5s/step - det_loss: 0.8099 - cls_loss: 0.4967 - box_loss: 0.0063 - reg_l2_loss: 0.0656 - loss: 0.8755 - learning_rate: 0.0132 - gradient_norm: 1.2692 - val_det_loss: 0.9309 - val_cls_loss: 0.5665 - val_box_loss: 0.0073 - val_reg_l2_loss: 0.0653 - val_loss: 0.9963\n",
"Epoch 9/20\n",
"3727/3727 [==============================] - 17868s 5s/step - det_loss: 0.7989 - cls_loss: 0.4908 - box_loss: 0.0062 - reg_l2_loss: 0.0651 - loss: 0.8640 - learning_rate: 0.0116 - gradient_norm: 1.2945 - val_det_loss: 0.9429 - val_cls_loss: 0.5695 - val_box_loss: 0.0075 - val_reg_l2_loss: 0.0648 - val_loss: 1.0077\n",
"Epoch 10/20\n",
"3727/3727 [==============================] - 18372s 5s/step - det_loss: 0.7906 - cls_loss: 0.4854 - box_loss: 0.0061 - reg_l2_loss: 0.0645 - loss: 0.8550 - learning_rate: 0.0100 - gradient_norm: 1.3375 - val_det_loss: 0.9003 - val_cls_loss: 0.5316 - val_box_loss: 0.0074 - val_reg_l2_loss: 0.0642 - val_loss: 0.9645\n",
"Epoch 11/20\n",
"3727/3727 [==============================] - 18466s 5s/step - det_loss: 0.7789 - cls_loss: 0.4792 - box_loss: 0.0060 - reg_l2_loss: 0.0639 - loss: 0.8428 - learning_rate: 0.0084 - gradient_norm: 1.3638 - val_det_loss: 0.8097 - val_cls_loss: 0.4916 - val_box_loss: 0.0064 - val_reg_l2_loss: 0.0636 - val_loss: 0.8734\n",
"Epoch 12/20\n",
"3727/3727 [==============================] - 18475s 5s/step - det_loss: 0.7732 - cls_loss: 0.4761 - box_loss: 0.0059 - reg_l2_loss: 0.0633 - loss: 0.8366 - learning_rate: 0.0068 - gradient_norm: 1.4166 - val_det_loss: 0.8956 - val_cls_loss: 0.5397 - val_box_loss: 0.0071 - val_reg_l2_loss: 0.0631 - val_loss: 0.9587\n",
"Epoch 13/20\n",
"3727/3727 [==============================] - 18503s 5s/step - det_loss: 0.7630 - cls_loss: 0.4708 - box_loss: 0.0058 - reg_l2_loss: 0.0628 - loss: 0.8258 - learning_rate: 0.0052 - gradient_norm: 1.4611 - val_det_loss: 0.9174 - val_cls_loss: 0.5333 - val_box_loss: 0.0077 - val_reg_l2_loss: 0.0626 - val_loss: 0.9800\n",
"Epoch 14/20\n",
"3727/3727 [==============================] - 18504s 5s/step - det_loss: 0.7548 - cls_loss: 0.4654 - box_loss: 0.0058 - reg_l2_loss: 0.0623 - loss: 0.8171 - learning_rate: 0.0039 - gradient_norm: 1.5110 - val_det_loss: 0.8530 - val_cls_loss: 0.5123 - val_box_loss: 0.0068 - val_reg_l2_loss: 0.0621 - val_loss: 0.9151\n",
"Epoch 15/20\n",
"3727/3727 [==============================] - 18990s 5s/step - det_loss: 0.7483 - cls_loss: 0.4624 - box_loss: 0.0057 - reg_l2_loss: 0.0620 - loss: 0.8103 - learning_rate: 0.0027 - gradient_norm: 1.5508 - val_det_loss: 0.7532 - val_cls_loss: 0.4612 - val_box_loss: 0.0058 - val_reg_l2_loss: 0.0618 - val_loss: 0.8150\n",
"Epoch 16/20\n",
"3727/3727 [==============================] - 18562s 5s/step - det_loss: 0.7427 - cls_loss: 0.4586 - box_loss: 0.0057 - reg_l2_loss: 0.0617 - loss: 0.8044 - learning_rate: 0.0016 - gradient_norm: 1.5845 - val_det_loss: 0.7677 - val_cls_loss: 0.4650 - val_box_loss: 0.0061 - val_reg_l2_loss: 0.0616 - val_loss: 0.8293\n",
"Epoch 17/20\n",
"3727/3727 [==============================] - 18550s 5s/step - det_loss: 0.7350 - cls_loss: 0.4541 - box_loss: 0.0056 - reg_l2_loss: 0.0615 - loss: 0.7965 - learning_rate: 8.5270e-04 - gradient_norm: 1.6145 - val_det_loss: 0.7575 - val_cls_loss: 0.4621 - val_box_loss: 0.0059 - val_reg_l2_loss: 0.0614 - val_loss: 0.8190\n",
"Epoch 18/20\n",
"3727/3727 [==============================] - 18554s 5s/step - det_loss: 0.7340 - cls_loss: 0.4538 - box_loss: 0.0056 - reg_l2_loss: 0.0614 - loss: 0.7954 - learning_rate: 3.1704e-04 - gradient_norm: 1.6074 - val_det_loss: 0.7528 - val_cls_loss: 0.4576 - val_box_loss: 0.0059 - val_reg_l2_loss: 0.0614 - val_loss: 0.8141\n",
"Epoch 19/20\n",
"3727/3727 [==============================] - 17843s 5s/step - det_loss: 0.7312 - cls_loss: 0.4524 - box_loss: 0.0056 - reg_l2_loss: 0.0614 - loss: 0.7926 - learning_rate: 4.5510e-05 - gradient_norm: 1.6218 - val_det_loss: 0.7502 - val_cls_loss: 0.4556 - val_box_loss: 0.0059 - val_reg_l2_loss: 0.0614 - val_loss: 0.8116\n",
"Epoch 20/20\n",
"3727/3727 [==============================] - 16614s 4s/step - det_loss: 0.7300 - cls_loss: 0.4511 - box_loss: 0.0056 - reg_l2_loss: 0.0614 - loss: 0.7913 - learning_rate: 4.5510e-05 - gradient_norm: 1.6283 - val_det_loss: 0.7491 - val_cls_loss: 0.4552 - val_box_loss: 0.0059 - val_reg_l2_loss: 0.0614 - val_loss: 0.8105\n"
]
}
],
"source": [
"model = object_detector.create(train_data=train_data, model_spec=spec, batch_size=16, train_whole_model=True, epochs=20, validation_data=val_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KB4hKeerMmh4"
},
"source": [
"### Step 4: Evaluate the model with the validation data.\n",
"\n",
"After training the object detection model using the images in the training dataset, use the screenshots in the validation dataset to evaluate how the model performs against new data it has never seen before.\n",
"\n",
"With the default batch size of 64, it takes 104 steps to go through the screenshots in the validation dataset.\n",
"\n",
"The evaluation metrics are the same as [COCO](https://cocodataset.org/#detection-eval)."
]
},
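{
"cell_type": "markdown",
"metadata": {},
"source": [
"The step counts quoted in Steps 3 and 4 can be sanity-checked from the number of screenshots in each folder. The cell below is a rough sketch under two assumptions that are inferred from the logs in this notebook rather than taken from the Model Maker documentation: training uses only full batches, and evaluation also processes the final partial batch. For example, 6,627 validation screenshots at a batch size of 64 would give ceil(6627 / 64) = 104 evaluation steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import os\n",
"\n",
"# Count the screenshots that the partition step copied into each folder.\n",
"num_train = sum(name.lower().endswith('.jpg') for name in os.listdir('./train'))\n",
"num_val = sum(name.lower().endswith('.jpg') for name in os.listdir('./test'))\n",
"\n",
"# Assumption: training drops the last partial batch (floor division).\n",
"print('training steps per epoch:', num_train // 16)\n",
"# Assumption: evaluation keeps the last partial batch (ceiling division).\n",
"print('evaluation steps:', math.ceil(num_val / 64))"
]
},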
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OUqEpcYwAg8L",
"outputId": "b359dc94-4906-4b5b-e022-77030b7a122f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"104/104 [==============================] - 442s 4s/step\n",
"\n"
]
},
{
"data": {
"text/plain": [
"{'AP': 0.23928128,\n",
" 'AP50': 0.34784698,\n",
" 'AP75': 0.26955733,\n",
" 'APs': 0.028244007,\n",
" 'APm': 0.12941888,\n",
" 'APl': 0.22108501,\n",
" 'ARmax1': 0.2461824,\n",
" 'ARmax10': 0.35139942,\n",
" 'ARmax100': 0.36341742,\n",
" 'ARs': 0.054406602,\n",
" 'ARm': 0.19773099,\n",
" 'ARl': 0.34407103,\n",
" 'AP_/Image': 0.16317038,\n",
" 'AP_/Input': 0.08524754,\n",
" 'AP_/On/Off Switch': 0.048255134,\n",
" 'AP_/Date Picker': 0.35675347,\n",
" 'AP_/Map View': 0.17753105,\n",
" 'AP_/List Item': 0.28585535,\n",
" 'AP_/Advertisement': 0.28919548,\n",
" 'AP_/Card': 0.20789924,\n",
" 'AP_/Checkbox': 0.1734771,\n",
" 'AP_/Drawer': 0.68573767,\n",
" 'AP_/Web View': 0.19177885,\n",
" 'AP_/Radio Button': 0.057744134,\n",
" 'AP_/Video': 0.0060326965,\n",
" 'AP_/Button Bar': 0.0023122337,\n",
" 'AP_/Text': 0.105855204,\n",
" 'AP_/Bottom Navigation': 0.3663874,\n",
" 'AP_/Toolbar': 0.5422325,\n",
" 'AP_/Number Stepper': 0.35072106,\n",
" 'AP_/Text Button': 0.2381703,\n",
" 'AP_/Pager Indicator': 0.000990099,\n",
" 'AP_/Icon': 0.34229085,\n",
" 'AP_/Slider': 0.0,\n",
" 'AP_/Modal': 0.63886,\n",
" 'AP_/Background Image': 0.46593183,\n",
" 'AP_/Multi-Tab': 0.19960245}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(val_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NARVYk9rGLIl"
},
"source": [
"### Step 5: Export the model for use with TFJS.\n",
"\n",
"Export the trained object detection model by specifying which folder you want to export it to. Because `ExportFormat.TFJS` is not yet supported, the cell below exports a SavedModel instead and zips the export folder. For TensorFlow Lite exports, the default post-training quantization technique is [full integer quantization](https://www.tensorflow.org/lite/performance/post_training_integer_quant). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_u3eFxoBAiqE",
"outputId": "f3fa26d4-fdbe-4ecf-d5a0-f39b76f75032"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"'zip' is not recognized as an internal or external command,\n",
"operable program or batch file.\n"
]
}
],
"source": [
"# ExportFormat.TFJS is not yet supported\n",
"# model.export(export_dir=\"./js_export/\", export_format=[ExportFormat.TFJS])\n",
"\n",
"model.export(export_dir=\"./js_export/\", export_format=[ExportFormat.SAVED_MODEL])\n",
"!zip -r ./js_export/ModelFiles.zip ./js_export/"
]
},
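{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use the exported SavedModel in TensorFlow.js, one option is the `tensorflowjs_converter` command-line tool from the `tensorflowjs` pip package. The cell below is a sketch, not a verified recipe: it assumes the export step wrote the model to `./js_export/saved_model`, the output folder `./js_export/web_model` is a hypothetical name, and the conversion may fail if the model contains ops that the converter does not support."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Assumption: the SavedModel from the previous cell is in ./js_export/saved_model.\n",
"%pip install -q tensorflowjs\n",
"!tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model ./js_export/saved_model ./js_export/web_model"
]
},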
{
"cell_type": "markdown",
"metadata": {
"id": "wnqktl45PZRy"
},
"source": [
"## (Optional) Test the detection model with TensorFlow Lite\n",
"\n",
"Let's test it with an image that the model hasn't seen before to get a sense of how good the model is.\n",
"\n",
"In this example, we also export a TFLite model and test how well it actually works. Its results may differ slightly from those of the TFJS model, so use them for reference only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v2h_pF2-osg7",
"outputId": "45d12473-41ee-46d4-cbff-0d5d2c3f7a51"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6627/6627 [==============================] - 27236s 4s/step\n",
"\n"
]
},
{
"data": {
"text/plain": [
"{'AP': 0.20490232,\n",
" 'AP50': 0.2911546,\n",
" 'AP75': 0.23508495,\n",
" 'APs': 0.021943321,\n",
" 'APm': 0.115160994,\n",
" 'APl': 0.18580821,\n",
" 'ARmax1': 0.2026951,\n",
" 'ARmax10': 0.26787397,\n",
" 'ARmax100': 0.26974252,\n",
" 'ARs': 0.032209698,\n",
" 'ARm': 0.15022126,\n",
" 'ARl': 0.24558628,\n",
" 'AP_/Image': 0.123515174,\n",
" 'AP_/Input': 0.07007123,\n",
" 'AP_/On/Off Switch': 0.045216367,\n",
" 'AP_/Date Picker': 0.22942775,\n",
" 'AP_/Map View': 0.1260201,\n",
" 'AP_/List Item': 0.22389112,\n",
" 'AP_/Advertisement': 0.23017684,\n",
" 'AP_/Card': 0.11379653,\n",
" 'AP_/Checkbox': 0.14040026,\n",
" 'AP_/Drawer': 0.65390044,\n",
" 'AP_/Web View': 0.14987761,\n",
" 'AP_/Radio Button': 0.03745377,\n",
" 'AP_/Video': 0.026732674,\n",
" 'AP_/Button Bar': 0.0,\n",
" 'AP_/Text': 0.07595385,\n",
" 'AP_/Bottom Navigation': 0.38563696,\n",
" 'AP_/Toolbar': 0.5175326,\n",
" 'AP_/Number Stepper': 0.3151305,\n",
" 'AP_/Text Button': 0.21079078,\n",
" 'AP_/Pager Indicator': 0.0014851486,\n",
" 'AP_/Icon': 0.3101346,\n",
" 'AP_/Slider': 0.0,\n",
" 'AP_/Modal': 0.5949396,\n",
" 'AP_/Background Image': 0.41141692,\n",
" 'AP_/Multi-Tab': 0.12905711}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.export(export_dir='.', tflite_filename='rico.tflite')\n",
"model.evaluate_tflite('rico.tflite', val_data)"
]
},
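{
"cell_type": "markdown",
"metadata": {},
"source": [
"The metrics returned by `model.evaluate` and by `model.evaluate_tflite` can be compared class by class to see where the exported TFLite model loses accuracy. A minimal sketch follows: `metric_deltas` is a hypothetical helper, and the two small dictionaries copy only three entries from the outputs above rather than the full results.\n",
"\n",
"```python\n",
"def metric_deltas(before, after):\n",
"  \"\"\"Returns {metric: after - before} for metrics present in both dicts.\"\"\"\n",
"  return {name: after[name] - before[name] for name in before if name in after}\n",
"\n",
"\n",
"# Three values copied from the evaluation outputs above.\n",
"keras_metrics = {'AP_/Toolbar': 0.5422325, 'AP_/Modal': 0.63886, 'AP_/Slider': 0.0}\n",
"tflite_metrics = {'AP_/Toolbar': 0.5175326, 'AP_/Modal': 0.5949396, 'AP_/Slider': 0.0}\n",
"\n",
"deltas = metric_deltas(keras_metrics, tflite_metrics)\n",
"# List the classes with the largest drop first.\n",
"for name, delta in sorted(deltas.items(), key=lambda item: item[1]):\n",
"  print(f'{name}: {delta:+.4f}')\n",
"```\n",
"\n",
"For these three classes the TFLite model scores lower on `Modal` and `Toolbar` and is unchanged on `Slider`."
]
},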
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "9ZsLQtJ1AlW_"
},
"outputs": [],
"source": [
"#@title Load the trained TFLite model and define some visualization functions\n",
"\n",
"#@markdown This code comes from the TFLite Object Detection [Raspberry Pi sample](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/raspberry_pi).\n",
"\n",
"import platform\n",
"from typing import List, NamedTuple\n",
"import json\n",
"\n",
"import cv2\n",
"\n",
"Interpreter = tf.lite.Interpreter\n",
"load_delegate = tf.lite.experimental.load_delegate\n",
"\n",
"# pylint: enable=g-import-not-at-top\n",
"\n",
"\n",
"class ObjectDetectorOptions(NamedTuple):\n",
" \"\"\"A config to initialize an object detector.\"\"\"\n",
"\n",
" enable_edgetpu: bool = False\n",
" \"\"\"Enable the model to run on EdgeTPU.\"\"\"\n",
"\n",
" label_allow_list: List[str] = None\n",
" \"\"\"The optional allow list of labels.\"\"\"\n",
"\n",
" label_deny_list: List[str] = None\n",
" \"\"\"The optional deny list of labels.\"\"\"\n",
"\n",
" max_results: int = -1\n",
" \"\"\"The maximum number of top-scored detection results to return.\"\"\"\n",
"\n",
" num_threads: int = 1\n",
" \"\"\"The number of CPU threads to be used.\"\"\"\n",
"\n",
" score_threshold: float = 0.0\n",
" \"\"\"The score threshold of detection results to return.\"\"\"\n",
"\n",
"\n",
"class Rect(NamedTuple):\n",
" \"\"\"A rectangle in 2D space.\"\"\"\n",
" left: float\n",
" top: float\n",
" right: float\n",
" bottom: float\n",
"\n",
"\n",
"class Category(NamedTuple):\n",
" \"\"\"A result of a classification task.\"\"\"\n",
" label: str\n",
" score: float\n",
" index: int\n",
"\n",
"\n",
"class Detection(NamedTuple):\n",
" \"\"\"A detected object as the result of an ObjectDetector.\"\"\"\n",
" bounding_box: Rect\n",
" categories: List[Category]\n",
"\n",
"\n",
"def edgetpu_lib_name():\n",
" \"\"\"Returns the library name of EdgeTPU in the current platform.\"\"\"\n",
" return {\n",
" 'Darwin': 'libedgetpu.1.dylib',\n",
" 'Linux': 'libedgetpu.so.1',\n",
" 'Windows': 'edgetpu.dll',\n",
" }.get(platform.system(), None)\n",
"\n",
"\n",
"class ObjectDetector:\n",
" \"\"\"A wrapper class for a TFLite object detection model.\"\"\"\n",
"\n",
" _OUTPUT_LOCATION_NAME = 'location'\n",
" _OUTPUT_CATEGORY_NAME = 'category'\n",
" _OUTPUT_SCORE_NAME = 'score'\n",
" _OUTPUT_NUMBER_NAME = 'number of detections'\n",
"\n",
" def __init__(\n",
" self,\n",
" model_path: str,\n",
" options: ObjectDetectorOptions = ObjectDetectorOptions()\n",
" ) -> None:\n",
" \"\"\"Initialize a TFLite object detection model.\n",
" Args:\n",
" model_path: Path to the TFLite model.\n",
" options: The config to initialize an object detector. (Optional)\n",
" Raises:\n",
" ValueError: If the TFLite model is invalid.\n",
" OSError: If the current OS isn't supported by EdgeTPU.\n",
" \"\"\"\n",
"\n",
" # Load metadata from model.\n",
" displayer = metadata.MetadataDisplayer.with_model_file(model_path)\n",
"\n",
" # Save model metadata for preprocessing later.\n",
" model_metadata = json.loads(displayer.get_metadata_json())\n",
" process_units = model_metadata['subgraph_metadata'][0]['input_tensor_metadata'][0]['process_units']\n",
" mean = 0.0\n",
" std = 1.0\n",
" for option in process_units:\n",
" if option['options_type'] == 'NormalizationOptions':\n",
" mean = option['options']['mean'][0]\n",
" std = option['options']['std'][0]\n",
" self._mean = mean\n",
" self._std = std\n",
"\n",
" # Load label list from metadata.\n",
" file_name = displayer.get_packed_associated_file_list()[0]\n",
" label_map_file = displayer.get_associated_file_buffer(file_name).decode()\n",
" label_list = list(filter(lambda x: len(x) > 0, label_map_file.splitlines()))\n",
" self._label_list = label_list\n",
"\n",
" # Initialize TFLite model.\n",
" if options.enable_edgetpu:\n",
" if edgetpu_lib_name() is None:\n",
" raise OSError(\"The current OS isn't supported by Coral EdgeTPU.\")\n",
" interpreter = Interpreter(\n",
" model_path=model_path,\n",
" experimental_delegates=[load_delegate(edgetpu_lib_name())],\n",
" num_threads=options.num_threads)\n",
" else:\n",
" interpreter = Interpreter(\n",
" model_path=model_path, num_threads=options.num_threads)\n",
"\n",
" interpreter.allocate_tensors()\n",
" input_detail = interpreter.get_input_details()[0]\n",
"\n",
" # From TensorFlow 2.6, the order of the outputs become undefined.\n",
" # Therefore we need to sort the tensor indices of TFLite outputs and to know\n",
" # exactly the meaning of each output tensor. For example, if\n",
" # output indices are [601, 599, 598, 600], tensor names and indices aligned\n",
" # are:\n",
" # - location: 598\n",
" # - category: 599\n",
" # - score: 600\n",
" # - detection_count: 601\n",
" # because of the op's ports of TFLITE_DETECTION_POST_PROCESS\n",
" # (https://github.com/tensorflow/tensorflow/blob/a4fe268ea084e7d323133ed7b986e0ae259a2bc7/tensorflow/lite/kernels/detection_postprocess.cc#L47-L50).\n",
" sorted_output_indices = sorted(\n",
" [output['index'] for output in interpreter.get_output_details()])\n",
" self._output_indices = {\n",
" self._OUTPUT_LOCATION_NAME: sorted_output_indices[0],\n",
" self._OUTPUT_CATEGORY_NAME: sorted_output_indices[1],\n",
" self._OUTPUT_SCORE_NAME: sorted_output_indices[2],\n",
" self._OUTPUT_NUMBER_NAME: sorted_output_indices[3],\n",
" }\n",
"\n",
" self._input_size = input_detail['shape'][2], input_detail['shape'][1]\n",
" self._is_quantized_input = input_detail['dtype'] == np.uint8\n",
" self._interpreter = interpreter\n",
" self._options = options\n",
"\n",
" def detect(self, input_image: np.ndarray) -> List[Detection]:\n",
" \"\"\"Run detection on an input image.\n",
" Args:\n",
" input_image: A [height, width, 3] RGB image. Note that height and width\n",
" can be anything since the image will be immediately resized according\n",
" to the needs of the model within this function.\n",
" Returns:\n",
" A Person instance.\n",
" \"\"\"\n",
" image_height, image_width, _ = input_image.shape\n",
"\n",
" input_tensor = self._preprocess(input_image)\n",
"\n",
" self._set_input_tensor(input_tensor)\n",
" self._interpreter.invoke()\n",
"\n",
" # Get all output details\n",
" boxes = self._get_output_tensor(self._OUTPUT_LOCATION_NAME)\n",
" classes = self._get_output_tensor(self._OUTPUT_CATEGORY_NAME)\n",
" scores = self._get_output_tensor(self._OUTPUT_SCORE_NAME)\n",
" count = int(self._get_output_tensor(self._OUTPUT_NUMBER_NAME))\n",
"\n",
" return self._postprocess(boxes, classes, scores, count, image_width,\n",
" image_height)\n",
"\n",
" def _preprocess(self, input_image: np.ndarray) -> np.ndarray:\n",
" \"\"\"Preprocess the input image as required by the TFLite model.\"\"\"\n",
"\n",
" # Resize the input\n",
" input_tensor = cv2.resize(input_image, self._input_size)\n",
"\n",
" # Normalize the input if it's a float model (aka. not quantized)\n",
" if not self._is_quantized_input:\n",
" input_tensor = (np.float32(input_tensor) - self._mean) / self._std\n",
"\n",
" # Add batch dimension\n",
" input_tensor = np.expand_dims(input_tensor, axis=0)\n",
"\n",
" return input_tensor\n",
"\n",
" def _set_input_tensor(self, image):\n",
" \"\"\"Sets the input tensor.\"\"\"\n",
" tensor_index = self._interpreter.get_input_details()[0]['index']\n",
" input_tensor = self._interpreter.tensor(tensor_index)()[0]\n",
" input_tensor[:, :] = image\n",
"\n",
" def _get_output_tensor(self, name):\n",
" \"\"\"Returns the output tensor at the given index.\"\"\"\n",
" output_index = self._output_indices[name]\n",
" tensor = np.squeeze(self._interpreter.get_tensor(output_index))\n",
" return tensor\n",
"\n",
" def _postprocess(self, boxes: np.ndarray, classes: np.ndarray,\n",
" scores: np.ndarray, count: int, image_width: int,\n",
" image_height: int) -> List[Detection]:\n",
" \"\"\"Post-process the output of TFLite model into a list of Detection objects.\n",
" Args:\n",
" boxes: Bounding boxes of detected objects from the TFLite model.\n",
" classes: Class index of the detected objects from the TFLite model.\n",
" scores: Confidence scores of the detected objects from the TFLite model.\n",
" count: Number of detected objects from the TFLite model.\n",
" image_width: Width of the input image.\n",
" image_height: Height of the input image.\n",
" Returns:\n",
" A list of Detection objects detected by the TFLite model.\n",
" \"\"\"\n",
" results = []\n",
"\n",
" # Parse the model output into a list of Detection entities.\n",
" for i in range(count):\n",
" if scores[i] >= self._options.score_threshold:\n",
" y_min, x_min, y_max, x_max = boxes[i]\n",
" bounding_box = Rect(\n",
" top=int(y_min * image_height),\n",
" left=int(x_min * image_width),\n",
" bottom=int(y_max * image_height),\n",
" right=int(x_max * image_width))\n",
" class_id = int(classes[i])\n",
" category = Category(\n",
" score=scores[i],\n",
" label=self._label_list[class_id], # 0 is reserved for background\n",
" index=class_id)\n",
" result = Detection(bounding_box=bounding_box, categories=[category])\n",
" results.append(result)\n",
"\n",
" # Sort detection results by score ascending\n",
" sorted_results = sorted(\n",
" results,\n",
" key=lambda detection: detection.categories[0].score,\n",
" reverse=True)\n",
"\n",
" # Filter out detections in deny list\n",
" filtered_results = sorted_results\n",
" if self._options.label_deny_list is not None:\n",
" filtered_results = list(\n",
" filter(\n",
" lambda detection: detection.categories[0].label not in self.\n",
" _options.label_deny_list, filtered_results))\n",
"\n",
" # Keep only detections in allow list\n",
" if self._options.label_allow_list is not None:\n",
" filtered_results = list(\n",
" filter(\n",
" lambda detection: detection.categories[0].label in self._options.\n",
" label_allow_list, filtered_results))\n",
"\n",
" # Only return maximum of max_results detection.\n",
" if self._options.max_results > 0:\n",
" result_count = min(len(filtered_results), self._options.max_results)\n",
" filtered_results = filtered_results[:result_count]\n",
"\n",
" return filtered_results\n",
"\n",
"\n",
"_MARGIN = 10 # pixels\n",
"_ROW_SIZE = 10 # pixels\n",
"_FONT_SIZE = 1\n",
"_FONT_THICKNESS = 1\n",
"_TEXT_COLOR = (0, 0, 255) # red\n",
"\n",
"\n",
"def visualize(\n",
" image: np.ndarray,\n",
" detections: List[Detection],\n",
") -> np.ndarray:\n",
" \"\"\"Draws bounding boxes on the input image and return it.\n",
" Args:\n",
" image: The input RGB image.\n",
" detections: The list of all \"Detection\" entities to be visualize.\n",
" Returns:\n",
" Image with bounding boxes.\n",
" \"\"\"\n",
" for detection in detections:\n",
" # Draw bounding_box\n",
" start_point = detection.bounding_box.left, detection.bounding_box.top\n",
" end_point = detection.bounding_box.right, detection.bounding_box.bottom\n",
" cv2.rectangle(image, start_point, end_point, _TEXT_COLOR, 3)\n",
"\n",
" # Draw label and score\n",
" category = detection.categories[0]\n",
" class_name = category.label\n",
" probability = round(category.score, 2)\n",
" result_text = class_name + ' (' + str(probability) + ')'\n",
" text_location = (_MARGIN + detection.bounding_box.left,\n",
" _MARGIN + _ROW_SIZE + detection.bounding_box.top)\n",
" cv2.putText(image, result_text, text_location, cv2.FONT_HERSHEY_PLAIN,\n",
" _FONT_SIZE, _TEXT_COLOR, _FONT_THICKNESS)\n",
"\n",
" return image"
]
},
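{
"cell_type": "markdown",
"metadata": {},
"source": [
"The comment in `ObjectDetector.__init__` above explains why the output tensor indices are sorted before they are used. A minimal sketch of that mapping in isolation follows, using the example indices from the comment. `map_output_indices` and `OUTPUT_NAMES` are hypothetical names introduced for illustration and are not part of the sample code.\n",
"\n",
"```python\n",
"# Output names in the port order of TFLITE_DETECTION_POST_PROCESS.\n",
"OUTPUT_NAMES = ('location', 'category', 'score', 'number of detections')\n",
"\n",
"\n",
"def map_output_indices(output_indices):\n",
"  \"\"\"Pairs each output name with a tensor index, in ascending index order.\"\"\"\n",
"  return dict(zip(OUTPUT_NAMES, sorted(output_indices)))\n",
"\n",
"\n",
"# The interpreter may report the indices in any order.\n",
"print(map_output_indices([601, 599, 598, 600]))\n",
"```\n",
"\n",
"With the indices from the comment, `location` maps to 598, `category` to 599, `score` to 600, and `number of detections` to 601."
]
},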
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 401
},
"id": "1t1z2fKlAoB0",
"outputId": "97e3d3d1-3168-4f6c-a754-ca1e6cf4f4a7"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"'wget'��(��) ���� �Ǵ� �ܺ� ����, ������ �� �ִ� ���α׷�, �Ǵ�\n",
"��ġ ������ �ƴմϴ�.\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<PIL.Image.Image image mode=RGB size=512x338 at 0x1F02F6812B0>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#@title Run object detection and show the detection results\n",
"\n",
"from PIL import Image\n",
"\n",
"INPUT_IMAGE_URL = \"https://miro.medium.com/max/1400/1*besTuD-m9aktHEJ2VRhCAA.png\" #@param {type:\"string\"}\n",
"DETECTION_THRESHOLD = 0.3 #@param {type:\"number\"}\n",
"TFLITE_MODEL_PATH = \"rico.tflite\" #@param {type:\"string\"}\n",
"\n",
"TEMP_FILE = 'image.png'\n",
"\n",
"!wget -q -O $TEMP_FILE $INPUT_IMAGE_URL\n",
"image = Image.open(TEMP_FILE).convert('RGB')\n",
"image.thumbnail((512, 512), Image.ANTIALIAS)\n",
"image_np = np.asarray(image)\n",
"\n",
"# Load the TFLite model\n",
"options = ObjectDetectorOptions(\n",
" num_threads=4,\n",
" score_threshold=DETECTION_THRESHOLD,\n",
")\n",
"detector = ObjectDetector(model_path=TFLITE_MODEL_PATH, options=options)\n",
"\n",
"# Run object detection estimation using the model.\n",
"detections = detector.detect(image_np)\n",
"\n",
"# Draw keypoints and edges on input image\n",
"image_np = visualize(image_np, detections)\n",
"\n",
"# Show the detection result\n",
"Image.fromarray(image_np)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "Model Maker Object Detection for RICO.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 0
}