Skip to content

Instantly share code, notes, and snippets.

@lezwon
Created June 6, 2021 09:05
Show Gist options
  • Save lezwon/53109eafb3105eea756e8d92d11a58ab to your computer and use it in GitHub Desktop.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""label_image for tflite."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import time
import numpy as np
from PIL import Image
import tflite_runtime.interpreter as tflite
class Classifier:
    """Single-image classifier backed by a TensorFlow Lite model.

    The model and label file are loaded once in the constructor; each call
    to :meth:`infer` runs one image through the interpreter and returns the
    label of the highest-scoring output.
    """

    def __init__(self, model_path, label_file, input_mean=127.5, input_std=127.5):
        """Load the labels and the model, and cache input tensor metadata.

        Args:
            model_path (string): Path of the ``.tflite`` model file.
            label_file (string): Path of a text file with one label per line.
            input_mean (float): Subtracted from every pixel for float models.
            input_std (float): Divides every pixel (after the mean shift)
                for float models.
        """
        self.input_mean = input_mean
        self.input_std = input_std
        self.labels = self.load_labels(label_file)
        self.interpreter = tflite.Interpreter(model_path=model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # check the type of the input tensor
        self.floating_model = self.input_details[0]['dtype'] == np.float32
        # NxHxWxC, H:1, W:2
        self.height = self.input_details[0]['shape'][1]
        self.width = self.input_details[0]['shape'][2]

    def load_labels(self, filename):
        """Return the lines of *filename* with surrounding whitespace stripped.

        Args:
            filename (string): Path of the label file (read as UTF-8).

        Returns:
            list of string: One entry per line, in file order.
        """
        # Explicit encoding: the locale default (e.g. cp1252 on Windows)
        # would fail or mis-decode non-ASCII label names.
        with open(filename, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    @staticmethod
    def _strip_label_index(label):
        """Return *label* without a leading ``"<integer> "`` prefix, if any.

        ``"0 background"`` becomes ``"background"``, while labels without a
        numeric prefix (``"tench"``, ``"great white shark"``) are returned
        unchanged instead of raising or losing their first word.
        """
        prefix, sep, rest = label.partition(' ')
        if sep and prefix.isdigit():
            return rest
        return label

    def infer(self, image):
        """
        Classify one image with the TF Lite model.

        Args:
            image (string): Path of the image

        Returns:
            label (string): Label of the highest-scoring class, with any
                leading numeric index from the label file removed.
        """
        input_shape = self.input_details[0]['shape']
        # Use a context manager so the file handle is closed right away
        # instead of whenever the Image object is garbage-collected.
        with Image.open(image) as img:
            # RGBA / palette / grayscale inputs would otherwise yield a channel
            # count that a 3-channel (NHWC, C == 3) model input rejects.
            if len(input_shape) == 4 and input_shape[3] == 3:
                img = img.convert('RGB')
            img = img.resize((self.width, self.height))
            # add N dim
            input_data = np.expand_dims(np.asarray(img), axis=0)
        if self.floating_model:
            input_data = (np.float32(input_data) - self.input_mean) / self.input_std
        # NOTE(review): non-float inputs are passed as the image's uint8 pixels;
        # an int8-quantized model would need its quantization params applied.
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()
        output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
        results = np.squeeze(output_data)
        best = results.argmax()
        return self._strip_label_index(self.labels[best])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment