Last active
October 18, 2018 06:39
-
-
Save pboos/117df482bebb16587b909290599467c5 to your computer and use it in GitHub Desktop.
Tensorflow Lite Android
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division, print_function, absolute_import | |
# library for optmising inference | |
from tensorflow.python.tools import optimize_for_inference_lib | |
from tensorflow.python.tools import freeze_graph | |
import tensorflow as tf | |
# Higher level API tflearn | |
import tflearn | |
from tflearn.data_utils import shuffle, to_categorical | |
from tflearn.layers.core import input_data, dropout, fully_connected | |
from tflearn.layers.conv import conv_2d, max_pool_2d | |
from tflearn.layers.estimator import regression | |
from tflearn.data_preprocessing import ImagePreprocessing | |
from tflearn.data_augmentation import ImageAugmentation | |
from tflearn.data_utils import image_preloader | |
import numpy as np | |
# Data loading and preprocessing | |
#helper functions to download the CIFAR 10 data and load them dynamically | |
# from tflearn.datasets import cifar10 | |
# (X, Y), (X_test, Y_test) = cifar10.load_data() | |
# X, Y = shuffle(X, Y) | |
# Y = to_categorical(Y,10) | |
# Y_test = to_categorical(Y_test,10) | |
# Locations of the image folder and of the generated train/test/validation
# listing files (all relative to the current working directory).
IMAGE_FOLDER = 'datasets/button_images'
TRAIN_DATA = 'datasets/training_data.txt'
TEST_DATA = 'datasets/test_data.txt'
VALIDATION_DATA = 'datasets/validation_data.txt'

# Side length (in pixels) used for the image shape and the input placeholder.
IMAGE_SIZE = 24

# Split proportions for the shuffled image list.
train_proportion = 0.7
test_proportion = 0.2
validation_proportion = 0.1

import glob
import os.path
import random
import math

# classes = filter(lambda f: not f.startswith('.'), os.listdir(IMAGE_FOLDER))
# classes.sort(key=str.lower)
# The position of a name in this list is the integer label written to the data files.
classes = ['close', 'pause', 'play', 'stop', 'other']
nrOfClasses = len(classes)
print('Classes: ' + str(classes))

# Every entry exactly two levels below IMAGE_FOLDER, i.e. <IMAGE_FOLDER>/<class>/<file>.
filesDepth2 = glob.glob(IMAGE_FOLDER + '/*/*')
# Materialise a real list: on Python 3 `filter` returns a one-shot iterator,
# which random.shuffle() and len() cannot handle.
images = [f for f in filesDepth2 if not os.path.isdir(f)]
random.shuffle(images)

# Directory containing this script; prepended to image paths in the data files.
dir_path = os.path.dirname(os.path.realpath(__file__))
def createDataFile(images, skipPercentage, percentage, dataFile,
                   imageFolder=None, classNames=None, baseDir=None):
    """Write one "<absolute path> <class index>" line per image of a slice of `images`.

    The slice starts at ceil(skipPercentage * len(images)) and ends at
    ceil((skipPercentage + percentage) * len(images)), so `percentage` is the
    SIZE of the slice as a fraction of the whole list, not its end position.

    Args:
        images: list of image paths of the form "<imageFolder>/<class>/<file>".
        skipPercentage: fraction (0..1) of the list to skip before the slice.
        percentage: fraction (0..1) of the list to include in the slice.
        dataFile: path of the text file to (over)write.
        imageFolder: root folder the paths start with; defaults to IMAGE_FOLDER.
        classNames: ordered class names, index == label; defaults to `classes`.
        baseDir: directory prepended to each path; defaults to `dir_path`.

    Images whose class folder is not listed in `classNames` are skipped.
    """
    # Fall back to the module-level configuration when not given explicitly.
    if imageFolder is None:
        imageFolder = IMAGE_FOLDER
    if classNames is None:
        classNames = classes
    if baseDir is None:
        baseDir = dir_path

    total = len(images)
    start = int(math.ceil(skipPercentage * total))
    end = int(math.ceil((skipPercentage + percentage) * total))

    # `with` guarantees the file is closed even if a path is malformed.
    with open(dataFile, 'w') as fr:
        for filename in images[start:end]:
            # The class name is the path component right after "<imageFolder>/".
            startClass = len(imageFolder) + 1
            endClass = filename.index('/', startClass)
            className = filename[startClass:endClass]
            if className not in classNames:
                continue
            fullPath = baseDir + '/' + filename
            fr.write(fullPath + ' ' + str(classNames.index(className)) + '\n')
# createDataFile takes (fraction to skip, fraction to take). Passing the END
# position as the third argument (0.9 / 1.0) made the test slice run to the end
# of the list, so it also contained every validation image. Pass slice sizes
# instead, and derive each skip value from the preceding proportions so the
# three slices are contiguous and disjoint.
test_skip = train_proportion
validation_skip = train_proportion + test_proportion
createDataFile(images, 0.0, train_proportion, TRAIN_DATA)
createDataFile(images, test_skip, test_proportion, TEST_DATA)
createDataFile(images, validation_skip, validation_proportion, VALIDATION_DATA)

# TODO maybe use grayscale=True
# Load each split from its listing file with identical preprocessing options.
X_train, Y_train = image_preloader(TRAIN_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE),
                                   mode='file', categorical_labels=True, normalize=True)
X_test, Y_test = image_preloader(TEST_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE),
                                 mode='file', categorical_labels=True, normalize=True)
X_val, Y_val = image_preloader(VALIDATION_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE),
                               mode='file', categorical_labels=True, normalize=True)
# Input image batch; the node name "ipnode" is what the export steps and the
# inference scripts look up, so it must not change.
x = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="ipnode")
# One-hot input class.
y_ = tf.placeholder(tf.float32, shape=[None, nrOfClasses], name='input_class')

# Small convolutional network (not AlexNet):
# conv -> pool -> conv -> conv -> pool -> dense(512) -> dense(nrOfClasses).
input_layer = x
# NOTE(review): the second argument is the filter count in tflearn's conv_2d;
# passing IMAGE_SIZE (24) here looks accidental - confirm the intended width.
network = conv_2d(input_layer, IMAGE_SIZE, 3, activation='relu')
network = max_pool_2d(network, 2)
network = conv_2d(network, 64, 3, activation='relu')
network = conv_2d(network, 64, 3, activation='relu')
network = max_pool_2d(network, 2)
network = fully_connected(network, 512, activation='relu')
# Linear activation: `network` holds the raw logits from here on.
network = fully_connected(network, nrOfClasses, activation='linear')
# Class probabilities; "opnode" is the exported output node name.
y_predicted = tf.nn.softmax(network, name="opnode")

# Loss function: cross-entropy computed directly from the logits. The previous
# form, log(softmax + exp(-nrOfClasses)), added a large constant (~0.0067 for
# five classes) inside the log to dodge log(0), which biased the loss.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=network))

# Optimiser
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# Accuracy: fraction of samples whose arg-max prediction equals the arg-max label.
correct_prediction = tf.equal(tf.argmax(y_predicted, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# TensorFlow session
sess = tf.Session()
# Initialising variables
init = tf.global_variables_initializer()
sess.run(init)
# TensorBoard graph dump for visualisation
writer = tf.summary.FileWriter('tensorboard/', sess.graph)

epoch = 50  # run for more iterations according to your hardware's power
# Change batch size according to your hardware's power. For GPUs use batch sizes
# in powers of 2 like 2, 4, 8, 16...
batch_size = 32
# Integer division: a trailing partial batch is never trained on.
no_itr_per_epoch = len(X_train) // batch_size
n_test = len(X_test)  # number of test samples

# The test tensors are identical for every epoch, so build them once instead of
# reloading and reshaping the whole test set on each pass.
x_test_images = np.reshape(X_test[0:n_test], [n_test, IMAGE_SIZE, IMAGE_SIZE, 3])
y_test_labels = np.reshape(Y_test[0:n_test], [n_test, nrOfClasses])

# Commencing training process
for iteration in range(epoch):
    print("Iteration no: {} ".format(iteration))
    previous_batch = 0
    # Do our mini batches: consecutive, non-overlapping windows of batch_size samples.
    for i in range(no_itr_per_epoch):
        current_batch = previous_batch + batch_size
        x_input = X_train[previous_batch:current_batch]
        x_images = np.reshape(x_input, [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3])
        y_input = Y_train[previous_batch:current_batch]
        y_label = np.reshape(y_input, [batch_size, nrOfClasses])
        previous_batch = previous_batch + batch_size
        _, loss = sess.run([train_step, cross_entropy], feed_dict={x: x_images, y_: y_label})
        # if i % 100 == 0:
        #     print("Training loss : {}".format(loss))

    # NOTE(review): the evaluation is assumed to run once per epoch (original
    # indentation was lost) - confirm against the intended behaviour.
    Accuracy_test = sess.run(accuracy,
                             feed_dict={
                                 x: x_test_images,
                                 y_: y_test_labels
                             })
    # Accuracy of the test set, as a percentage with two decimals
    Accuracy_test = round(Accuracy_test * 100, 2)
    print("Accuracy ::  Test_set {} %  ".format(Accuracy_test))

# Nothing else is written to TensorBoard; close the writer so buffered events
# reach disk.
writer.close()
#####################
## Save graph definition and checkpoint
#####################
model_directory = 'model_files/'
saver = tf.train.Saver()
# Text-format GraphDef (structure only) ...
tf.train.write_graph(sess.graph_def, model_directory, 'savegraph.pbtxt')
# ... and the trained variable values.
saver.save(sess, 'model_files/model.ckpt')
sess.close()

#################
## Freeze the graph
#################
MODEL_NAME = 'button'

# Inputs produced by the save step above.
input_graph_path = 'model_files/savegraph.pbtxt'
checkpoint_path = 'model_files/model.ckpt'
input_saver_def_path = ""
input_binary = False  # savegraph.pbtxt is text, not binary protobuf

# Node names given to the input placeholder and the softmax output.
input_node_names = "ipnode"
output_node_names = "opnode"

# NOTE(review): this placeholder is created after the graph was already written
# and is not referenced by any code visible here; presumably a leftover - confirm.
input_nodes = tf.placeholder(tf.float32, shape=[1, IMAGE_SIZE, IMAGE_SIZE, 3], name=input_node_names)
output_nodes = y_predicted

restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"

# Artifacts written by the freeze / optimize / convert steps.
output_frozen_graph_name = 'model_files/model_frozen_' + MODEL_NAME + '.pb'
output_optimized_graph_name = 'model_files/model_optimized_' + MODEL_NAME + '.pb'
output_converted_graph_name = 'model_files/model_converted_' + MODEL_NAME + '.tflite'
clear_devices = True

# Combine graph structure and checkpoint values into one frozen GraphDef file.
freeze_graph.freeze_graph(
    input_graph=input_graph_path,
    input_saver=input_saver_def_path,
    input_binary=input_binary,
    input_checkpoint=checkpoint_path,
    output_node_names=output_node_names,
    restore_op_name=restore_op_name,
    filename_tensor_name=filename_tensor_name,
    output_graph=output_frozen_graph_name,
    clear_devices=clear_devices,
    initializer_nodes="")
#################
## optimize graph
#################
input_graph_def = tf.GraphDef()
# The frozen graph is a binary protobuf: read it as bytes ("rb"). Text mode
# breaks on Python 3, where the content would be decoded as a string.
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    [input_node_names],   # an array of the input node(s)
    [output_node_names],  # an array of output nodes
    tf.float32.as_datatype_enum)

# Save optimized graph; binary mode plus a context manager so the serialized
# bytes are flushed and the handle is closed.
with tf.gfile.FastGFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())
#################
## convert graph
#################
from subprocess import call

# Run the `toco` command-line converter on the FROZEN graph (not the optimized
# one) to produce a float TFLite model with a fixed 1 x IMAGE_SIZE x IMAGE_SIZE x 3 input.
toco_exit_code = call([
    "toco",
    "--graph_def_file=" + output_frozen_graph_name,
    "--input_format=TENSORFLOW_GRAPHDEF",
    "--output_format=TFLITE",
    "--output_file=" + output_converted_graph_name,
    "--input_shape=1," + str(IMAGE_SIZE) + "," + str(IMAGE_SIZE) + ",3",
    "--input_type=FLOAT",
    "--input_array=" + input_node_names,
    "--output_array=" + output_node_names,
    "--inference_type=FLOAT",
    "--inference_input_type=FLOAT"
])
# `call` does not raise on failure; report a non-zero exit status so a failed
# conversion is not mistaken for success.
if toco_exit_code != 0:
    print("toco conversion failed with exit code " + str(toco_exit_code))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
# Smoke test: load the converted model, run it once on random input and print
# the names and shapes of its first input and output tensors.
interpreter = tf.contrib.lite.Interpreter(model_path="model_files/model_converted_button.tflite")
interpreter.allocate_tensors()

# Descriptions of the model's input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
first_input = input_details[0]
first_output = output_details[0]

input_shape = first_input['shape']
output_shape = first_output['shape']

# Random float32 input of the model's input shape.
# Change the following line to feed in your own data.
input_data = np.random.random_sample(input_shape).astype(np.float32)
interpreter.set_tensor(first_input['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(first_output['index'])

print('input name: ' + str(first_input['name']))
print('input shape: ' + str(input_shape))
print('output name: ' + str(first_output['name']))
print('output shape: ' + str(output_shape))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import sys | |
import time | |
import numpy as np | |
import tensorflow as tf | |
def load_graph(model_file):
    """Read a binary GraphDef from `model_file` and import it into a new tf.Graph.

    The graph is imported with import_graph_def's default name scope; the caller
    in this file looks nodes up as "import/<node name>".

    Returns:
        The newly created tf.Graph containing the imported nodes.
    """
    with open(model_file, "rb") as model_handle:
        serialized = model_handle.read()

    parsed_def = tf.GraphDef()
    parsed_def.ParseFromString(serialized)

    imported = tf.Graph()
    with imported.as_default():
        tf.import_graph_def(parsed_def)
    return imported
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
    """Decode an image file into a normalised float tensor with a batch axis.

    The decoder is picked from the file extension (.png, .gif, .bmp, anything
    else is treated as JPEG). The image is cast to float32, given a leading
    batch dimension, bilinearly resized to (input_height, input_width) and
    normalised as (pixel - input_mean) / input_std.

    Args:
        file_name: path of the image to load.
        input_height: target height in pixels.
        input_width: target width in pixels.
        input_mean: value subtracted from every pixel.
        input_std: value every pixel is divided by.

    Returns:
        The evaluated result of the normalisation op (a numpy array).
    """
    print(input_mean)
    print(input_std)
    input_name = "file_reader"
    file_reader = tf.read_file(file_name, input_name)
    if file_name.endswith(".png"):
        image_reader = tf.image.decode_png(file_reader, channels=3,
                                           name='png_reader')
    elif file_name.endswith(".gif"):
        # decode_gif yields a frames axis; squeeze drops it for single-frame GIFs.
        image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                      name='gif_reader'))
    elif file_name.endswith(".bmp"):
        image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
    else:
        image_reader = tf.image.decode_jpeg(file_reader, channels=3,
                                            name='jpeg_reader')
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0)
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    # Use the session as a context manager so its resources are released;
    # previously it was created and never closed.
    with tf.Session() as sess:
        result = sess.run(normalized)
    return result
def load_labels(label_file):
    """Return the lines of `label_file` as a list, trailing whitespace stripped.

    Args:
        label_file: path of a text file with one label per line.
    """
    # Context manager closes the file handle, which was previously left open.
    with tf.gfile.GFile(label_file) as label_handle:
        return [line.rstrip() for line in label_handle.readlines()]
if __name__ == "__main__":
    # Command-line options; every one is optional and overrides a built-in default.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", help="image to be processed")
    parser.add_argument("--graph", help="graph/model to be executed")
    parser.add_argument("--labels", help="name of file containing labels")
    parser.add_argument("--input_height", type=int, help="input height")
    parser.add_argument("--input_width", type=int, help="input width")
    parser.add_argument("--input_mean", type=int, help="input mean")
    parser.add_argument("--input_std", type=int, help="input std")
    parser.add_argument("--input_layer", help="name of input layer")
    parser.add_argument("--output_layer", help="name of output layer")
    args = parser.parse_args()

    # A missing or falsy option falls back to its default. input_mean is the
    # exception: an explicit 0 is accepted, only an absent value uses the default.
    model_file = args.graph or "model_files/model_optimized_button.pb"
    file_name = args.image or "datasets/button_images/other/1.jpeg"
    label_file = args.labels or "labels.txt"
    input_height = args.input_height or 24
    input_width = args.input_width or 24
    input_mean = args.input_mean if args.input_mean is not None else 0
    input_std = args.input_std or 255
    input_layer = args.input_layer or "ipnode"
    output_layer = args.output_layer or "opnode"

    graph = load_graph(model_file)
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)

    # Nodes are looked up under the "import/" prefix of the loaded graph.
    input_name = "import/" + input_layer
    output_name = "import/" + output_layer
    input_operation = graph.get_operation_by_name(input_name)
    output_operation = graph.get_operation_by_name(output_name)

    # Time a single forward pass.
    with tf.Session(graph=graph) as sess:
        start = time.time()
        results = sess.run(output_operation.outputs[0],
                           {input_operation.outputs[0]: t})
        end = time.time()
    results = np.squeeze(results)

    # Indices of the (up to) five highest scores, best first.
    top_k = results.argsort()[::-1][:5]
    labels = load_labels(label_file)

    print('\nEvaluation time (1-image): {:.3f}s\n'.format(end - start))
    template = "{} (score={:0.5f})"
    for i in top_k:
        print(template.format(labels[i], results[i]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import sys | |
import time | |
import numpy as np | |
import tensorflow as tf | |
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
    """Decode an image file into a normalised float tensor with a batch axis.

    The decoder is picked from the file extension (.png, .gif, .bmp, anything
    else is treated as JPEG). The image is cast to float32, given a leading
    batch dimension, bilinearly resized to (input_height, input_width) and
    normalised as (pixel - input_mean) / input_std.

    Args:
        file_name: path of the image to load.
        input_height: target height in pixels.
        input_width: target width in pixels.
        input_mean: value subtracted from every pixel.
        input_std: value every pixel is divided by.

    Returns:
        The evaluated result of the normalisation op (a numpy array).
    """
    input_name = "file_reader"
    file_reader = tf.read_file(file_name, input_name)
    if file_name.endswith(".png"):
        image_reader = tf.image.decode_png(file_reader, channels=3,
                                           name='png_reader')
    elif file_name.endswith(".gif"):
        # decode_gif yields a frames axis; squeeze drops it for single-frame GIFs.
        image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                      name='gif_reader'))
    elif file_name.endswith(".bmp"):
        image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
    else:
        image_reader = tf.image.decode_jpeg(file_reader, channels=3,
                                            name='jpeg_reader')
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0)
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    # Use the session as a context manager so its resources are released;
    # previously it was created and never closed.
    with tf.Session() as sess:
        result = sess.run(normalized)
    return result
def load_labels(label_file):
    """Return the lines of `label_file` as a list, trailing whitespace stripped.

    Args:
        label_file: path of a text file with one label per line.
    """
    # Context manager closes the file handle, which was previously left open.
    with tf.gfile.GFile(label_file) as label_handle:
        return [line.rstrip() for line in label_handle.readlines()]
if __name__ == "__main__":
    # Node names of the trained model; kept for reference, the TFLite
    # interpreter below addresses tensors by index instead.
    input_layer = "ipnode"
    output_layer = "opnode"

    # Command-line options; every one is optional and overrides a built-in default.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", help="image to be processed")
    parser.add_argument("--graph", help="graph/model to be executed")
    parser.add_argument("--labels", help="name of file containing labels")
    parser.add_argument("--input_height", type=int, help="input height")
    parser.add_argument("--input_width", type=int, help="input width")
    parser.add_argument("--input_mean", type=int, help="input mean")
    parser.add_argument("--input_std", type=int, help="input std")
    args = parser.parse_args()

    # A missing or falsy option falls back to its default. input_mean is the
    # exception: an explicit 0 is accepted, only an absent value uses the default.
    model_file = args.graph or "model_files/model_converted_button.tflite"
    file_name = args.image or "datasets/button_images/other/1.jpeg"
    label_file = args.labels or "labels.txt"
    input_height = args.input_height or 24
    input_width = args.input_width or 24
    input_mean = args.input_mean if args.input_mean is not None else 0
    input_std = args.input_std or 255

    # Load TFLite model and allocate tensors.
    interpreter = tf.contrib.lite.Interpreter(model_path=model_file)
    interpreter.allocate_tensors()

    # Descriptions of the model's input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    input_shape = input_details[0]['shape']

    # Feed the preprocessed image into the first input and run inference.
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)
    interpreter.set_tensor(input_details[0]['index'], t)
    interpreter.invoke()

    results = np.squeeze(interpreter.get_tensor(output_details[0]['index']))

    # Indices of the (up to) five highest scores, best first.
    top_k = results.argsort()[::-1][:5]
    labels = load_labels(label_file)

    template = "{} (score={:0.5f})"
    for i in top_k:
        print(template.format(labels[i], results[i]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.carecon.tensorflow | |
import android.content.Context | |
import android.graphics.Bitmap | |
import android.graphics.BitmapFactory | |
import android.graphics.Canvas | |
import android.graphics.Color | |
import android.graphics.ColorMatrix | |
import android.graphics.ColorMatrixColorFilter | |
import android.graphics.Paint | |
import android.os.Bundle | |
import android.support.v7.app.AppCompatActivity | |
import kotlinx.android.synthetic.main.activity_main.* | |
import org.tensorflow.lite.Interpreter | |
import java.io.FileInputStream | |
import java.io.IOException | |
import java.io.InputStream | |
import java.nio.ByteBuffer | |
import java.nio.ByteOrder | |
import java.nio.MappedByteBuffer | |
import java.nio.channels.FileChannel | |
/**
 * Demo activity: loads `button.jpg` from the assets, scales it to the model's
 * input size, classifies it with a TensorFlow Lite model on a background
 * thread and shows the scaled image plus the sorted scores.
 */
class MainActivity : AppCompatActivity() {

    /** Created on first use; the model file is read from the app's assets. */
    private val classifier: Classifier by lazy { TFMobileClassifier(this, modelFilename = "model.tflite") }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        runTensorflowLite()
    }

    override fun onDestroy() {
        super.onDestroy()
        classifier.close()
    }

    /**
     * Runs one prediction off the UI thread and posts the result to the views.
     * Does nothing further when the asset bitmap cannot be loaded.
     */
    private fun runTensorflowLite() {
        Thread(Runnable {
            val start = System.currentTimeMillis()
            val bitmap = getBitmapFromAsset(this, "button.jpg") ?: return@Runnable
            val scaled = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true)
            runOnUiThread { image.setImageBitmap(scaled) }
//            val monochrome = toMonochrome(bitmap, 56)
            val inputData = convertBitmapToByteBuffer(scaled)
            val predictions = classifier.predict(inputData)
            // The measured duration covers asset loading, scaling and inference.
            val duration = System.currentTimeMillis() - start
            val predictionsString = predictions.joinToString(separator = "\n") { (label, score) -> "$label: $score" }
            val outputText = "$predictionsString\n\nDuration: ${duration}ms"
            runOnUiThread { text.text = outputText }
        }).start()
    }

    /**
     * Decodes the asset at [filePath] into a bitmap.
     *
     * @return the decoded bitmap, or `null` when the asset cannot be opened,
     * read or decoded.
     */
    private fun getBitmapFromAsset(context: Context, filePath: String): Bitmap? =
        try {
            // `use` closes the stream; a failure while closing is caught below
            // instead of escaping from a `finally` block.
            context.assets.open(filePath).use { stream -> BitmapFactory.decodeStream(stream) }
        } catch (e: IOException) {
            // Previously swallowed silently; leave a trace for debugging.
            android.util.Log.w(TAG, "Unable to load asset $filePath", e)
            null
        }

    /**
     * Scales [bitmap] to [size] x [size], removes colour saturation and then
     * thresholds every pixel to pure black or white (blue channel < 128 -> black).
     */
    private fun toMonochrome(bitmap: Bitmap, size: Int): Bitmap {
        val scaled = Bitmap.createScaledBitmap(bitmap, size, size, false)

        // Draw the scaled bitmap through a zero-saturation colour filter.
        val monochrome = Bitmap.createBitmap(size, size, Bitmap.Config.ARGB_8888)
        val paint = Paint().apply {
            colorFilter = ColorMatrixColorFilter(ColorMatrix().apply { setSaturation(0f) })
        }
        Canvas(monochrome).drawBitmap(scaled, 0f, 0f, paint)

        // Threshold in bulk on the pixel array instead of one getPixel/setPixel
        // call per pixel (the array used to be fetched and then ignored).
        val width = monochrome.width
        val height = monochrome.height
        val pixels = IntArray(width * height)
        monochrome.getPixels(pixels, 0, width, 0, 0, width, height)
        for (index in pixels.indices) {
            val lowestByte = pixels[index] and 0xff
            pixels[index] = if (lowestByte < 128) Color.BLACK else Color.WHITE
        }
        monochrome.setPixels(pixels, 0, width, 0, 0, width, height)
        return monochrome
    }

    /**
     * Packs [bitmap] into a direct, native-order buffer of 32-bit floats:
     * three values (red, green, blue) per pixel in row-major pixel order, each
     * scaled to the range 0..1. The alpha channel is dropped.
     */
    private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
        val pixels = IntArray(bitmap.width * bitmap.height)
        bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

        val byteBuffer = ByteBuffer.allocateDirect(pixels.size * CHANNELS * BYTES_PER_FLOAT)
        byteBuffer.order(ByteOrder.nativeOrder())
        // getPixels already yields row-major order, so one pass over the array
        // suffices (the former nested width/height loops only counted pixels).
        for (pixel in pixels) {
            byteBuffer.putFloat((pixel shr 16 and 0xFF) / 255f)
            byteBuffer.putFloat((pixel shr 8 and 0xFF) / 255f)
            byteBuffer.putFloat((pixel and 0xFF) / 255f)
        }
        // Reset the position so consumers reading relative to it start at the first value.
        byteBuffer.rewind()
        return byteBuffer
    }

    /** Minimal classifier abstraction used by this activity. */
    interface Classifier {
        /** Returns (label, score) pairs sorted by descending score. */
        fun predict(input: ByteBuffer): List<Pair<String, Float>>

        /** Releases the resources held by the classifier. */
        fun close()
    }

    /**
     * [Classifier] backed by a TensorFlow Lite [Interpreter] whose model is
     * memory-mapped from the asset named [modelFilename].
     */
    class TFMobileClassifier(context: Context, private val modelFilename: String) : Classifier {

        // Same order as the `classes` list in the training script: index == model output index.
        private val labels = arrayOf("close", "pause", "play", "stop", "other")

        private val inferenceInterface = Interpreter(loadModelFile(context))

        override fun predict(input: ByteBuffer): List<Pair<String, Float>> {
            // One output row (batch of 1) with one score per label.
            val predictions = Array(1) { FloatArray(labels.size) }
            inferenceInterface.run(input, predictions)
            return predictions[0]
                .mapIndexed { index, score -> Pair(labels[index], score) }
                .sortedByDescending { it.second }
        }

        override fun close() {
            inferenceInterface.close()
        }

        /**
         * Memory-map the model file in Assets.
         *
         * The asset descriptor and the stream are closed before returning; a
         * mapping stays valid after the channel that created it is closed.
         */
        @Throws(IOException::class)
        private fun loadModelFile(context: Context): MappedByteBuffer {
            val fileDescriptor = context.assets.openFd(modelFilename)
            try {
                return FileInputStream(fileDescriptor.fileDescriptor).use { inputStream ->
                    inputStream.channel.map(
                        FileChannel.MapMode.READ_ONLY,
                        fileDescriptor.startOffset,
                        fileDescriptor.declaredLength
                    )
                }
            } finally {
                fileDescriptor.close()
            }
        }
    }

    private companion object {
        const val TAG = "MainActivity"

        /** Width and height, in pixels, of the image fed to the model. */
        const val INPUT_SIZE = 24

        /** Colour channels written per pixel (red, green, blue). */
        const val CHANNELS = 3

        /** Size of one 32-bit float in bytes. */
        const val BYTES_PER_FLOAT = 4
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: a 100dp preview image above a result text that is centred in the parent. -->
<android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <!-- Preview of the scaled input image, centred between the top edge and the text.
         "@+id/text" is used because the id is referenced here before the TextView
         declares it; a plain "@id/text" forward reference can fail to resolve on
         older resource compilers. -->
    <ImageView
        android:id="@+id/image"
        android:layout_width="100dp"
        android:layout_height="100dp"
        app:layout_constraintBottom_toTopOf="@+id/text"
        app:layout_constraintLeft_toLeftOf="parent"
        app:layout_constraintRight_toRightOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <!-- Prediction output; the placeholder text is replaced at runtime. -->
    <TextView
        android:id="@+id/text"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Hello World!"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintLeft_toLeftOf="parent"
        app:layout_constraintRight_toRightOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

</android.support.constraint.ConstraintLayout>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment