Created
August 12, 2019 18:50
-
-
Save ayyucedemirbas/6c2d6bd9324834432df02e8083be9031 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2019 Google LLC | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# https://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
r"""A demo to demonstrate transfer learning for classification model. | |
Args: | |
- model_path | |
Path of base model, e.g., | |
'test_data/imprinting/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite' | |
- data | |
Path to the directory of data set, e.g., 'test_data/open_image_v4_subset'. | |
Please notice that you need to run 'test_data/download_imprinting_test_data.sh' | |
to generate the data set. | |
- output | |
Output name of the trained model. By default it is | |
'[model_name]_retrained.tflite'. | |
- test_ratio | |
The ratio of images used for test. By default it's 0.25. | |
- keep_classes | |
Bool, whether to keep base model classes. It is False if not set. | |
Steps: | |
- Under the parent directory `edgetpu/`. | |
- Prepares the data set for transfer learning. | |
Run 'bash test_data/download_imprinting_test_data.sh' to download the data | |
we prepared. There are 10 categories, 20 images for each category. 200 | |
images in total. | |
- Run this demo to create the new classification model. | |
python3 edgetpu/demo/classification_transfer_learning.py \
--model_path='test_data/imprinting/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite' \
--data='test_data/open_image_v4_subset' \
--output='my_model.tflite'
- Verify with Classification model. | |
'my_model.tflite' and 'my_model.txt' (the labels file) produced by the last
step can be treated the same as a normal classification model. You can use
ClassificationEngine for verification or further development.
python3 edgetpu/demo/classify_image.py --model='my_model.tflite' \ | |
--label='my_model.txt' --image='test_data/cat.bmp' | |
""" | |
import argparse | |
import os | |
from edgetpu.basic.basic_engine import BasicEngine | |
from edgetpu.classification.engine import ClassificationEngine | |
from edgetpu.learn.imprinting.engine import ImprintingEngine | |
import numpy as np | |
from PIL import Image | |
def _ReadData(path, test_ratio):
    """Parses data from given directory and splits it into two sets.

    Categories, and the images inside each category, are visited in sorted
    order. `os.listdir` returns entries in arbitrary order, so without sorting
    the split and the order of the returned dicts could differ between runs.

    Args:
      path: string, path of the data set. Images are stored in sub-directory
        named by category.
      test_ratio: float in (0,1), ratio of data used for testing.

    Returns:
      (train_set, test_set), A tuple of two dicts. Keys are the categories and
      values are lists of image file names. Entries of `path` that are not
      directories, and directories that contain no file, are skipped.

    Raises:
      AssertionError: if a category has no image left for training after the
        test images were taken out.
    """
    train_set = {}
    test_set = {}
    for category in sorted(os.listdir(path)):
        category_dir = os.path.join(path, category)
        if not os.path.isdir(category_dir):
            continue
        images = sorted(
            f for f in os.listdir(category_dir)
            if os.path.isfile(os.path.join(category_dir, f)))
        if not images:
            continue
        # Always hold out at least one image for testing, so the test list of
        # a non-empty category can never be empty.
        k = max(int(test_ratio * len(images)), 1)
        test_set[category] = images[:k]
        train_set[category] = images[k:]
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        if not train_set[category]:
            raise AssertionError('No images to train [{}]'.format(category))
    return train_set, test_set
def _PrepareImages(image_list, directory, shape):
    """Loads images and turns each one into a flat numpy array.

    Every image is converted to RGB and resized to `shape` with nearest
    neighbour resampling before being flattened.

    Args:
      image_list: a list of strings storing file names.
      directory: string, path of directory storing input images.
      shape: a 2-D tuple, the size passed to `Image.resize`.

    Returns:
      A numpy.array built from the list of flattened images.
    """
    flattened = []
    for name in image_list:
        image_path = os.path.join(directory, name)
        with Image.open(image_path) as source:
            resized = source.convert('RGB').resize(shape, Image.NEAREST)
            flattened.append(np.asarray(resized).flatten())
    return np.array(flattened)
def _SaveLabels(labels, model_path):
    """Output labels as a txt file next to the model.

    The labels file name is derived from `model_path`: a trailing '.tflite'
    extension is swapped for '.txt'; any other path just gets '.txt' appended.
    This guarantees the labels file never has the same name as the model file
    (a plain `str.replace` would return the model path itself when it holds no
    '.tflite', and would also rewrite '.tflite' inside directory names).

    Args:
      labels: {int : string}, map between label id and label.
      model_path: string, path of the model.
    """
    root, ext = os.path.splitext(model_path)
    base = root if ext == '.tflite' else model_path
    label_file_name = base + '.txt'
    with open(label_file_name, 'w') as f:
        # One "<id> <label>" pair per line.
        f.writelines(
            '{} {}\n'.format(label_id, label)
            for label_id, label in labels.items())
    print('Labels file saved as :', label_file_name)
def _GetRequiredShape(model_path):
    """Looks up the image size expected by the model's input tensor.

    Args:
      model_path: string, path of the model.

    Returns:
      (width, height), taken from elements 2 and 1 of the input tensor shape
      reported by `BasicEngine`.
    """
    engine = BasicEngine(model_path)
    tensor_shape = engine.get_input_tensor_shape()
    width = tensor_shape[2]
    height = tensor_shape[1]
    return (width, height)
def _GetOutputNumberClasses(model_path):
    """Reports how many output classes the model has.

    The model is expected to have exactly one output tensor; this is asserted
    before the size is read.

    Args:
      model_path: string, path of the model.

    Returns:
      int, the total output array size reported by `BasicEngine`.
    """
    engine = BasicEngine(model_path)
    tensor_count = engine.get_num_of_output_tensors()
    assert tensor_count == 1
    return engine.total_output_array_size()
def _ParseArgs():
    """Parses command line args and fills in defaults for omitted ones.

    Defaults:
      - output: base name of `--model_path` with a trailing '.tflite' replaced
        by '_retrained.tflite' (or '_retrained.tflite' appended when the name
        has no '.tflite' extension, so the default never equals the base
        model's file name).
      - test_ratio: 0.25.

    Returns:
      Object with attributes. Each attribute represents an argument.

    Raises:
      SystemExit: raised by argparse when required arguments are missing or
        when `--test_ratio` is not strictly between 0 and 1.
    """
    print('---------------------- Args ----------------------')
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path', help='Path to the model path.', required=True)
    parser.add_argument(
        '--data', help=('Path to the training set, images are stored '
                        'under sub-directory named by category.'),
        required=True)
    parser.add_argument(
        '--output', help='Name of the trained model.')
    parser.add_argument(
        '--test_ratio', type=float,
        help='Float number in (0,1), ratio of data used for test data.')
    parser.add_argument(
        '--keep_classes', action='store_true',
        help='Whether to keep base model classes.')
    args = parser.parse_args()

    if not args.output:
        model_name = os.path.basename(args.model_path)
        root, ext = os.path.splitext(model_name)
        base = root if ext == '.tflite' else model_name
        args.output = base + '_retrained.tflite'
    print('Output path :', args.output)

    # By default, choose 25% data for test. Only an omitted flag triggers the
    # default; an explicit 0 must reach the range check below.
    if args.test_ratio is None:
        args.test_ratio = 0.25
    # parser.error instead of assert: asserts are stripped under `python -O`.
    if not 0 < args.test_ratio < 1.0:
        parser.error('--test_ratio must be a float in (0,1), got {}'.format(
            args.test_ratio))
    print('Ratio of test images: {:.0%}'.format(args.test_ratio))
    return args
def main():
    """Runs the demo: parse args, train, save, then evaluate the new model.

    Steps:
      - Split the data set into per-category training and test image lists.
      - Convert the training images to arrays and pass them, one array per
        category, to `ImprintingEngine.TrainAll`.
      - Save the trained model and a labels file mapping the enumeration
        index of each category to its name.
      - Classify every test image with the saved model and print cumulative
        top-1 ... top-5 accuracy.
    """
    args = _ParseArgs()
    print('--------------- Parsing data set -----------------')
    print('Dataset path:', args.data)
    train_set, test_set = _ReadData(args.data, args.test_ratio)
    print('Image list successfully parsed! Category Num = ', len(train_set))
    shape = _GetRequiredShape(args.model_path)

    print('---------------- Processing training data ----------------')
    print('This process may take more than 30 seconds.')
    train_input = []
    labels_map = {}
    for class_id, (category, image_list) in enumerate(train_set.items()):
        print('Processing category:', category)
        train_input.append(
            _PrepareImages(
                image_list, os.path.join(args.data, category), shape))
        labels_map[class_id] = category

    print('---------------- Start training -----------------')
    engine = ImprintingEngine(args.model_path, keep_classes=args.keep_classes)
    engine.TrainAll(train_input)
    print('---------------- Training finished! -----------------')
    engine.SaveModel(args.output)
    print('Model saved as : ', args.output)
    _SaveLabels(labels_map, args.output)

    print('------------------ Start evaluating ------------------')
    engine = ClassificationEngine(args.output)
    top_k = 5
    correct = [0] * top_k
    wrong = [0] * top_k
    for category, image_list in test_set.items():
        print('Evaluating category [', category, ']')
        for img_name in image_list:
            # Open inside `with` so the image file handle is released after
            # each classification instead of leaking one per test image.
            with Image.open(
                    os.path.join(args.data, category, img_name)) as img:
                candidates = engine.ClassifyWithImage(
                    img, threshold=0.1, top_k=top_k)
            # Cumulative accuracy: once the right label shows up at rank i,
            # the image counts as correct for rank i and every later rank.
            recognized = False
            for i in range(top_k):
                # .get(): a class id missing from labels_map is treated as a
                # miss instead of raising KeyError.
                if (i < len(candidates)
                        and labels_map.get(candidates[i][0]) == category):
                    recognized = True
                if recognized:
                    correct[i] = correct[i] + 1
                else:
                    wrong[i] = wrong[i] + 1

    print('---------------- Evaluation result -----------------')
    for i in range(top_k):
        evaluated = correct[i] + wrong[i]
        if evaluated == 0:
            # Nothing was classified; avoid dividing by zero.
            print('No test images were evaluated.')
            break
        print('Top {} : {:.0%}'.format(i + 1, correct[i] / evaluated))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment