Skip to content

Instantly share code, notes, and snippets.

@pboos
Last active October 18, 2018 06:39
Show Gist options
  • Save pboos/117df482bebb16587b909290599467c5 to your computer and use it in GitHub Desktop.
Save pboos/117df482bebb16587b909290599467c5 to your computer and use it in GitHub Desktop.
Tensorflow Lite Android
from __future__ import division, print_function, absolute_import
# library for optimising inference
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.python.tools import freeze_graph
import tensorflow as tf
# Higher level API tflearn
import tflearn
from tflearn.data_utils import shuffle, to_categorical
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_augmentation import ImageAugmentation
from tflearn.data_utils import image_preloader
import numpy as np
# Data loading and preprocessing.
# The CIFAR-10 loading code below is kept for reference only; this script
# trains on the button images found under IMAGE_FOLDER instead.
# from tflearn.datasets import cifar10
# (X, Y), (X_test, Y_test) = cifar10.load_data()
# X, Y = shuffle(X, Y)
# Y = to_categorical(Y,10)
# Y_test = to_categorical(Y_test,10)

# Root folder; each direct sub-folder name is treated as a class name.
IMAGE_FOLDER = 'datasets/button_images'
# Generated "<absolute path> <class index>" list files read by image_preloader.
TRAIN_DATA = 'datasets/training_data.txt'
TEST_DATA = 'datasets/test_data.txt'
VALIDATION_DATA = 'datasets/validation_data.txt'
# Side length (pixels) that images are resized to when preloaded.
IMAGE_SIZE = 24
# Fractions of the shuffled image list per split (intended to sum to 1.0).
train_proportion = 0.7
test_proportion = 0.2
validation_proportion = 0.1

import glob
import os.path
import random
import math

# Explicit, fixed class order: the index in this list is the numeric label the
# model learns, and the Android client uses the same order.
# classes = filter(lambda f: not f.startswith('.'), os.listdir(IMAGE_FOLDER))
# classes.sort(key=str.lower)
classes = ['close', 'pause', 'play', 'stop', 'other']
nrOfClasses = len(classes)
print('Classes: ' + str(classes))

# Every entry exactly two levels below IMAGE_FOLDER (<class>/<image>).
filesDepth2 = glob.glob(IMAGE_FOLDER + '/*/*')
# Must be a real list: on Python 3 filter() returns a lazy iterator, which
# random.shuffle() cannot shuffle (it raised TypeError).
images = [f for f in filesDepth2 if not os.path.isdir(f)]
random.shuffle(images)
# Absolute directory of this script; used to build absolute image paths.
dir_path = os.path.dirname(os.path.realpath(__file__))
def createDataFile(images, skipPercentage, percentage, dataFile,
                   image_folder=None, class_names=None, base_dir=None):
    """Write one split of ``images`` to ``dataFile`` in image_preloader format.

    The split is the slice
    ``images[ceil(skipPercentage * n):ceil((skipPercentage + percentage) * n)]``
    with ``n = len(images)``. Note that ``percentage`` is the *size* of the
    split, not its end position. Each kept image produces one line
    ``"<base_dir>/<path> <class index>"``. Images whose class folder is not in
    ``class_names`` are skipped.

    Args:
        images: Image paths of the form ``<image_folder>/<class>/<file>``.
        skipPercentage: Fraction of ``images`` skipped before the split starts.
        percentage: Fraction of ``images`` that belongs to this split.
        dataFile: Output path; an existing file is overwritten.
        image_folder: Root folder of ``images``; defaults to ``IMAGE_FOLDER``.
        class_names: Ordered class names; defaults to ``classes``.
        base_dir: Directory prefixed to every path; defaults to ``dir_path``.
    """
    if image_folder is None:
        image_folder = IMAGE_FOLDER
    if class_names is None:
        class_names = classes
    if base_dir is None:
        base_dir = dir_path

    total = len(images)
    start = int(math.ceil(skipPercentage * total))
    end = int(math.ceil((skipPercentage + percentage) * total))

    # `with` guarantees the file is closed even if writing fails midway.
    with open(dataFile, 'w') as fr:
        for filename in images[start:end]:
            # The first path component below image_folder is the class name.
            # os.path handles the mixed '/' and '\\' separators glob returns on
            # Windows, unlike the former hard-coded '/' search.
            relative = os.path.relpath(filename, image_folder)
            className = relative.split(os.sep)[0]
            if className not in class_names:
                continue
            fullPath = os.path.join(base_dir, filename)
            fr.write(fullPath + ' ' + str(class_names.index(className)) + '\n')
# Split the shuffled image list into consecutive, non-overlapping
# train / test / validation ranges. createDataFile() takes (start, size), so
# the third argument is the size of each split. The former calls passed end
# positions (0.9, 1.0) as sizes, which made the test split run to 100% and
# overlap the validation split.
createDataFile(images, 0.0, train_proportion, TRAIN_DATA)
createDataFile(images, train_proportion, test_proportion, TEST_DATA)
createDataFile(images, train_proportion + test_proportion,
               validation_proportion, VALIDATION_DATA)

# Load each split, resized to IMAGE_SIZE x IMAGE_SIZE, with one-hot labels and
# pixel values normalised by image_preloader (normalize=True).
# TODO maybe use grayscale=True
X_train, Y_train = image_preloader(TRAIN_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE),
                                   mode='file', categorical_labels=True, normalize=True)
X_test, Y_test = image_preloader(TEST_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE),
                                 mode='file', categorical_labels=True, normalize=True)
X_val, Y_val = image_preloader(VALIDATION_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE),
                               mode='file', categorical_labels=True, normalize=True)
# Input images (RGB, IMAGE_SIZE x IMAGE_SIZE). "ipnode" is the input array
# name used later when freezing and converting the model.
x = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="ipnode")
# One-hot ground-truth labels.
y_ = tf.placeholder(tf.float32, shape=[None, nrOfClasses], name='input_class')

# Small CNN: conv -> pool -> conv -> conv -> pool -> dense -> dense (logits).
# The first layer's filter count was previously written as IMAGE_SIZE; it only
# happens to equal 24, so it is now an independent constant with the same
# value (changing the image size no longer silently changes the architecture).
conv1_filters = 24
input_layer = x
network = conv_2d(input_layer, conv1_filters, 3, activation='relu')
network = max_pool_2d(network, 2)
network = conv_2d(network, 64, 3, activation='relu')
network = conv_2d(network, 64, 3, activation='relu')
network = max_pool_2d(network, 2)
network = fully_connected(network, 512, activation='relu')
network = fully_connected(network, nrOfClasses, activation='linear')
# Class probabilities; "opnode" is the output array name exported to TFLite.
y_predicted = tf.nn.softmax(network, name="opnode")

# Loss: mean cross-entropy computed directly from the logits with TensorFlow's
# fused, numerically stable op. The former log(softmax + exp(-nrOfClasses))
# added an ad-hoc epsilon of ~0.0067 (for 5 classes) that distorted the loss.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=network))
# Optimiser.
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# Accuracy: fraction of samples whose arg-max prediction matches the label.
correct_prediction = tf.equal(tf.argmax(y_predicted, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# TensorFlow session used for training and, later, for saving the checkpoint.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# TensorBoard graph visualisation.
writer = tf.summary.FileWriter('tensorboard/', sess.graph)

epoch = 50  # run for more iterations according to your hardware's power
# Change batch size according to your hardware's power; on GPUs prefer
# powers of 2 like 2, 4, 8, 16...
batch_size = 32
# Round up so the final partial batch is trained on as well. Floor division
# skipped it, and a training set smaller than batch_size produced zero
# iterations per epoch.
no_itr_per_epoch = (len(X_train) + batch_size - 1) // batch_size
n_test = len(X_test)  # number of test samples

# The test set never changes, so convert it to arrays once.
x_test_images = np.reshape(X_test[0:n_test], [n_test, IMAGE_SIZE, IMAGE_SIZE, 3])
y_test_labels = np.reshape(Y_test[0:n_test], [n_test, nrOfClasses])

# Commencing training process
for iteration in range(epoch):
    print("Iteration no: {} ".format(iteration))
    # Do our mini batches:
    for i in range(no_itr_per_epoch):
        batch_start = i * batch_size
        batch_end = batch_start + batch_size
        # -1 lets the last (possibly smaller) batch keep its real size.
        x_images = np.reshape(X_train[batch_start:batch_end], [-1, IMAGE_SIZE, IMAGE_SIZE, 3])
        y_label = np.reshape(Y_train[batch_start:batch_end], [-1, nrOfClasses])
        _, loss = sess.run([train_step, cross_entropy], feed_dict={x: x_images, y_: y_label})
        # if i % 100 == 0:
        #     print("Training loss : {}".format(loss))

    # Accuracy of the test set after this epoch, in percent.
    Accuracy_test = sess.run(accuracy, feed_dict={x: x_test_images, y_: y_test_labels})
    Accuracy_test = round(Accuracy_test * 100, 2)
    print("Accuracy :: Test_set {} % ".format(Accuracy_test))

# Flush and close the TensorBoard event file.
writer.close()
# ---------------------------------------------------------------------------
# Persist the trained model: the graph definition (text format) and the
# variable values (checkpoint). Both are inputs for the freezing step.
# ---------------------------------------------------------------------------
model_directory = 'model_files/'
saver = tf.train.Saver()
tf.train.write_graph(sess.graph_def, model_directory, 'savegraph.pbtxt')
saver.save(sess, os.path.join(model_directory, 'model.ckpt'))
sess.close()
# ---------------------------------------------------------------------------
# Freeze the graph: fold the checkpointed variables into constants so that
# the graph can be optimised and converted without the checkpoint.
# ---------------------------------------------------------------------------
MODEL_NAME = 'button'
input_graph_path = 'model_files/savegraph.pbtxt'
checkpoint_path = 'model_files/model.ckpt'
input_saver_def_path = ""
input_binary = False  # savegraph.pbtxt was written in text format
input_node_names = "ipnode"
output_node_names = "opnode"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'model_files/model_frozen_' + MODEL_NAME + '.pb'
output_optimized_graph_name = 'model_files/model_optimized_' + MODEL_NAME + '.pb'
output_converted_graph_name = 'model_files/model_converted_' + MODEL_NAME + '.tflite'
clear_devices = True

# The former `input_nodes = tf.placeholder(..., name="ipnode")` and
# `output_nodes = y_predicted` were removed. Neither was used, and the extra
# placeholder was added to the default graph after the graph had already been
# written to disk (freeze_graph reads from the files above).
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_frozen_graph_name, clear_devices, "")
# ---------------------------------------------------------------------------
# Optimise the frozen graph for inference between the input and output nodes.
# ---------------------------------------------------------------------------
input_graph_def = tf.GraphDef()
# The frozen graph is a binary protobuf. Read it in binary mode: "r" yields
# str on Python 3, which ParseFromString rejects.
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    [input_node_names],  # an array of the input node(s)
    [output_node_names],  # an array of output nodes
    tf.float32.as_datatype_enum)
# Save the optimised graph. Use binary mode, and `with` so the file is flushed
# and closed (the former FastGFile handle was never closed).
with tf.gfile.GFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())
# ---------------------------------------------------------------------------
# Convert the frozen graph to TensorFlow Lite with the `toco` CLI.
# ---------------------------------------------------------------------------
from subprocess import check_call

# check_call raises CalledProcessError when toco fails. The former call()
# ignored the exit status, so later steps would load a missing or stale
# .tflite file.
check_call([
    "toco",
    "--graph_def_file=" + output_frozen_graph_name,
    "--input_format=TENSORFLOW_GRAPHDEF",
    "--output_format=TFLITE",
    "--output_file=" + output_converted_graph_name,
    "--input_shape=1," + str(IMAGE_SIZE) + "," + str(IMAGE_SIZE) + ",3",
    "--input_type=FLOAT",
    "--input_array=" + input_node_names,
    "--output_array=" + output_node_names,
    "--inference_type=FLOAT",
    "--inference_input_type=FLOAT"
])
import numpy as np
import tensorflow as tf
# Smoke-test the converted model: load it, run one inference on a random
# input of the model's input shape, and report the I/O tensor signature.
interpreter = tf.contrib.lite.Interpreter(model_path="model_files/model_converted_button.tflite")
interpreter.allocate_tensors()

# Tensor metadata (name, index, shape, ...) for the model's inputs/outputs.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
output_shape = output_details[0]['shape']

# Random float32 data; replace with real data to inspect actual predictions.
input_data = np.random.random_sample(input_shape).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

for label, value in (('input name', input_details[0]['name']),
                     ('input shape', input_shape),
                     ('output name', output_details[0]['name']),
                     ('output shape', output_shape)):
    print(label + ': ' + str(value))
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import time
import numpy as np
import tensorflow as tf
def load_graph(model_file):
    """Load a binary serialized GraphDef file into a fresh ``tf.Graph``.

    The nodes are imported with TensorFlow's default ``import/`` name prefix.
    """
    with open(model_file, "rb") as handle:
        serialized = handle.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def)
    return graph
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
    """Decode an image file into a normalised float32 batch of one image.

    The file type is chosen by extension (.png, .gif, .bmp, otherwise JPEG).
    The image is decoded with 3 channels, bilinearly resized to
    ``(input_height, input_width)`` and normalised as
    ``(pixel - input_mean) / input_std``.

    Returns:
        numpy array of shape ``(1, input_height, input_width, 3)``.
    """
    input_name = "file_reader"
    output_name = "normalized"
    # Build the decode pipeline in a private graph so repeated calls do not keep
    # adding ops to the default graph, and close the session when done (the
    # former version leaked one Session per call and printed debug output).
    with tf.Graph().as_default():
        file_reader = tf.read_file(file_name, input_name)
        if file_name.endswith(".png"):
            image_reader = tf.image.decode_png(file_reader, channels=3,
                                               name='png_reader')
        elif file_name.endswith(".gif"):
            # decode_gif yields (frames, h, w, 3); squeeze drops the frame axis
            # of a single-frame GIF.
            image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                          name='gif_reader'))
        elif file_name.endswith(".bmp"):
            # Force 3 channels: a BMP may carry alpha, but the model input has 3.
            image_reader = tf.image.decode_bmp(file_reader, channels=3,
                                               name='bmp_reader')
        else:
            image_reader = tf.image.decode_jpeg(file_reader, channels=3,
                                                name='jpeg_reader')
        float_caster = tf.cast(image_reader, tf.float32)
        dims_expander = tf.expand_dims(float_caster, 0)
        resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
        normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std],
                               name=output_name)
        with tf.Session() as sess:
            return sess.run(normalized)
def load_labels(label_file):
    """Return the lines of ``label_file`` with trailing whitespace stripped.

    Line ``i`` is printed as the label for model output index ``i``.
    """
    # `with` closes the GFile handle, which the former version leaked.
    with tf.gfile.GFile(label_file) as f:
        return [line.rstrip() for line in f.readlines()]
if __name__ == "__main__":
    # Defaults for the button classifier; every value can be overridden below.
    file_name = "datasets/button_images/other/1.jpeg"
    model_file = "model_files/model_optimized_button.pb"
    label_file = "labels.txt"
    input_height = 24
    input_width = 24
    input_mean = 0
    input_std = 255
    input_layer = "ipnode"
    output_layer = "opnode"

    parser = argparse.ArgumentParser()
    parser.add_argument("--image", help="image to be processed")
    parser.add_argument("--graph", help="graph/model to be executed")
    parser.add_argument("--labels", help="name of file containing labels")
    parser.add_argument("--input_height", type=int, help="input height")
    parser.add_argument("--input_width", type=int, help="input width")
    parser.add_argument("--input_mean", type=int, help="input mean")
    parser.add_argument("--input_std", type=int, help="input std")
    parser.add_argument("--input_layer", help="name of input layer")
    parser.add_argument("--output_layer", help="name of output layer")
    args = parser.parse_args()

    # A command-line value replaces the default only when it is truthy, except
    # input_mean, where 0 is a meaningful value and only None means "not given".
    model_file = args.graph or model_file
    file_name = args.image or file_name
    label_file = args.labels or label_file
    input_height = args.input_height or input_height
    input_width = args.input_width or input_width
    if args.input_mean is not None:
        input_mean = args.input_mean
    input_std = args.input_std or input_std
    input_layer = args.input_layer or input_layer
    output_layer = args.output_layer or output_layer

    graph = load_graph(model_file)
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)

    # load_graph imports nodes under the "import/" prefix.
    input_operation = graph.get_operation_by_name("import/" + input_layer)
    output_operation = graph.get_operation_by_name("import/" + output_layer)

    with tf.Session(graph=graph) as sess:
        start = time.time()
        results = sess.run(output_operation.outputs[0],
                           {input_operation.outputs[0]: t})
        end = time.time()
    results = np.squeeze(results)

    # Indices of the five highest scores, best first.
    top_k = np.argsort(results)[::-1][:5]
    labels = load_labels(label_file)

    print('\nEvaluation time (1-image): {:.3f}s\n'.format(end - start))
    template = "{} (score={:0.5f})"
    for idx in top_k:
        print(template.format(labels[idx], results[idx]))
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import time
import numpy as np
import tensorflow as tf
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
    """Decode an image file into a normalised float32 batch of one image.

    The file type is chosen by extension (.png, .gif, .bmp, otherwise JPEG).
    The image is decoded with 3 channels, bilinearly resized to
    ``(input_height, input_width)`` and normalised as
    ``(pixel - input_mean) / input_std``.

    Returns:
        numpy array of shape ``(1, input_height, input_width, 3)``.
    """
    input_name = "file_reader"
    output_name = "normalized"
    # Build the decode pipeline in a private graph so repeated calls do not keep
    # adding ops to the default graph, and close the session when done (the
    # former version leaked one Session per call).
    with tf.Graph().as_default():
        file_reader = tf.read_file(file_name, input_name)
        if file_name.endswith(".png"):
            image_reader = tf.image.decode_png(file_reader, channels=3,
                                               name='png_reader')
        elif file_name.endswith(".gif"):
            # decode_gif yields (frames, h, w, 3); squeeze drops the frame axis
            # of a single-frame GIF.
            image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                          name='gif_reader'))
        elif file_name.endswith(".bmp"):
            # Force 3 channels: a BMP may carry alpha, but the model input has 3.
            image_reader = tf.image.decode_bmp(file_reader, channels=3,
                                               name='bmp_reader')
        else:
            image_reader = tf.image.decode_jpeg(file_reader, channels=3,
                                                name='jpeg_reader')
        float_caster = tf.cast(image_reader, tf.float32)
        dims_expander = tf.expand_dims(float_caster, 0)
        resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
        normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std],
                               name=output_name)
        with tf.Session() as sess:
            return sess.run(normalized)
def load_labels(label_file):
    """Return the lines of ``label_file`` with trailing whitespace stripped.

    Line ``i`` is printed as the label for model output index ``i``.
    """
    # `with` closes the GFile handle, which the former version leaked.
    with tf.gfile.GFile(label_file) as f:
        return [line.rstrip() for line in f.readlines()]
if __name__ == "__main__":
    # Defaults for the converted button classifier; overridable from the CLI.
    # (The former input_layer/output_layer defaults were never used: the TFLite
    # interpreter is addressed by tensor index below, not by name.)
    file_name = "datasets/button_images/other/1.jpeg"
    model_file = "model_files/model_converted_button.tflite"
    label_file = "labels.txt"
    input_height = 24
    input_width = 24
    input_mean = 0
    input_std = 255

    parser = argparse.ArgumentParser()
    parser.add_argument("--image", help="image to be processed")
    parser.add_argument("--graph", help="graph/model to be executed")
    parser.add_argument("--labels", help="name of file containing labels")
    parser.add_argument("--input_height", type=int, help="input height")
    parser.add_argument("--input_width", type=int, help="input width")
    parser.add_argument("--input_mean", type=int, help="input mean")
    parser.add_argument("--input_std", type=int, help="input std")
    args = parser.parse_args()
    if args.graph:
        model_file = args.graph
    if args.image:
        file_name = args.image
    if args.labels:
        label_file = args.labels
    if args.input_height:
        input_height = args.input_height
    if args.input_width:
        input_width = args.input_width
    # 0 is a valid mean, so only None means "not given".
    if args.input_mean is not None:
        input_mean = args.input_mean
    if args.input_std:
        input_std = args.input_std

    # Load TFLite model and allocate tensors.
    interpreter = tf.contrib.lite.Interpreter(model_path=model_file)
    interpreter.allocate_tensors()
    # Get input and output tensor metadata.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    input_shape = input_details[0]['shape']

    # Run the model on the given image (not random data).
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)
    # Fail with a clear message instead of an opaque set_tensor error when the
    # requested image size does not match the model's input tensor.
    if tuple(t.shape) != tuple(input_shape):
        raise ValueError(
            "Image tensor shape {} does not match model input shape {}; "
            "adjust --input_height/--input_width".format(
                tuple(t.shape), tuple(int(d) for d in input_shape)))
    interpreter.set_tensor(input_details[0]['index'], t)
    interpreter.invoke()
    results = np.squeeze(interpreter.get_tensor(output_details[0]['index']))

    # Indices of the five highest scores, best first.
    top_k = results.argsort()[-5:][::-1]
    labels = load_labels(label_file)
    template = "{} (score={:0.5f})"
    for i in top_k:
        print(template.format(labels[i], results[i]))
package com.carecon.tensorflow
import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.ColorMatrix
import android.graphics.ColorMatrixColorFilter
import android.graphics.Paint
import android.os.Bundle
import android.support.v7.app.AppCompatActivity
import kotlinx.android.synthetic.main.activity_main.*
import org.tensorflow.lite.Interpreter
import java.io.FileInputStream
import java.io.IOException
import java.io.InputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
/**
 * Demo screen that runs a TensorFlow Lite button classifier on a bundled sample
 * image and shows the scaled input together with the per-class scores.
 */
class MainActivity : AppCompatActivity() {

    // Held as an explicit Lazy so onDestroy() can check whether the model was
    // ever loaded. Reading `classifier` there would otherwise load the model
    // only to close it again.
    private val classifierDelegate: Lazy<Classifier> = lazy {
        TFMobileClassifier(this, modelFilename = MODEL_ASSET)
    }
    private val classifier: Classifier by classifierDelegate

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        runTensorflowLite()
    }

    /**
     * Loads the sample image and classifies it on a background thread. Posts
     * the scaled image and the sorted scores (plus elapsed time) to the UI.
     */
    private fun runTensorflowLite() {
        Thread(Runnable {
            val start = System.currentTimeMillis()
            val bitmap = getBitmapFromAsset(this, SAMPLE_IMAGE_ASSET)
            if (bitmap == null) {
                // Report the failure instead of silently keeping the placeholder text.
                runOnUiThread { text.text = "Could not load asset: $SAMPLE_IMAGE_ASSET" }
                return@Runnable
            }
            // The model was trained on INPUT_SIZE x INPUT_SIZE images.
            val scaled = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true)
            runOnUiThread { image.setImageBitmap(scaled) }
            // val monochrome = toMonochrome(bitmap, 56)
            val inputData = convertBitmapToByteBuffer(scaled)
            val predictions = classifier.predict(inputData)
            val duration = System.currentTimeMillis() - start
            val predictionsString = predictions.joinToString(separator = "\n") { "${it.first}: ${it.second}" }
            val outputText = "$predictionsString\n\nDuration: ${duration}ms"
            runOnUiThread { text.text = outputText }
        }).start()
    }

    override fun onDestroy() {
        super.onDestroy()
        if (classifierDelegate.isInitialized()) {
            classifier.close()
        }
    }

    /**
     * Decodes the asset at [filePath]. Returns null if it cannot be opened or
     * decoded. The stream is always closed.
     */
    private fun getBitmapFromAsset(context: Context, filePath: String): Bitmap? =
        try {
            context.assets.open(filePath).use { stream -> BitmapFactory.decodeStream(stream) }
        } catch (e: IOException) {
            // Best effort by design: the caller treats null as "asset unavailable".
            null
        }

    /**
     * Scales [bitmap] to [size] x [size], desaturates it and thresholds every
     * pixel to pure black or white at gray level 128.
     */
    private fun toMonochrome(bitmap: Bitmap, size: Int): Bitmap {
        val scaled = Bitmap.createScaledBitmap(bitmap, size, size, false)
        val monochrome = Bitmap.createBitmap(size, size, Bitmap.Config.ARGB_8888)
        // Draw through a zero-saturation colour filter to get a grayscale copy.
        val grayscale = ColorMatrix().apply { setSaturation(0f) }
        val paint = Paint().apply { colorFilter = ColorMatrixColorFilter(grayscale) }
        Canvas(monochrome).drawBitmap(scaled, 0f, 0f, paint)
        // Threshold in one bulk pass over the pixel array (the array was
        // previously filled but unused, with per-pixel get/set calls instead).
        // After desaturation, the lowest byte (blue) carries the gray level.
        val pixels = IntArray(size * size)
        monochrome.getPixels(pixels, 0, size, 0, 0, size, size)
        for (index in pixels.indices) {
            pixels[index] = if ((pixels[index] and 0xFF) < 128) Color.BLACK else Color.WHITE
        }
        monochrome.setPixels(pixels, 0, size, 0, 0, size, size)
        return monochrome
    }

    /**
     * Packs [bitmap] into a direct, native-order buffer of row-major RGB floats
     * in [0, 1]. This mirrors the `normalize=True` preprocessing used in training.
     */
    private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
        val pixelCount = bitmap.width * bitmap.height
        val byteBuffer = ByteBuffer.allocateDirect(pixelCount * CHANNELS * BYTES_PER_FLOAT)
        byteBuffer.order(ByteOrder.nativeOrder())
        val pixels = IntArray(pixelCount)
        bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
        for (pixel in pixels) {
            byteBuffer.putFloat((pixel shr 16 and 0xFF) / 255f)
            byteBuffer.putFloat((pixel shr 8 and 0xFF) / 255f)
            byteBuffer.putFloat((pixel and 0xFF) / 255f)
        }
        // Reset the position so the buffer is read from the start.
        byteBuffer.rewind()
        return byteBuffer
    }

    /** Image classifier returning (label, score) pairs sorted by descending score. */
    interface Classifier {
        fun predict(input: ByteBuffer): List<Pair<String, Float>>
        fun close()
    }

    /** [Classifier] backed by a TensorFlow Lite model stored in the app's assets. */
    class TFMobileClassifier(context: Context, private val modelFilename: String) : Classifier {
        private val inferenceInterface = Interpreter(loadModelFile(context))

        override fun predict(input: ByteBuffer): List<Pair<String, Float>> {
            val predictions = Array(1) { FloatArray(CLASSES.size) }
            inferenceInterface.run(input, predictions)
            return predictions[0]
                .mapIndexed { index, score -> CLASSES[index] to score }
                .sortedByDescending { it.second }
        }

        override fun close() {
            inferenceInterface.close()
        }

        /**
         * Memory-maps the model file in assets. The descriptor and stream are
         * closed afterwards; the mapping stays valid once established.
         */
        @Throws(IOException::class)
        private fun loadModelFile(context: Context): MappedByteBuffer =
            context.assets.openFd(modelFilename).use { fileDescriptor ->
                FileInputStream(fileDescriptor.fileDescriptor).use { inputStream ->
                    inputStream.channel.map(
                        FileChannel.MapMode.READ_ONLY,
                        fileDescriptor.startOffset,
                        fileDescriptor.declaredLength
                    )
                }
            }

        private companion object {
            // Output index -> label; must match the `classes` order used in training.
            val CLASSES = arrayOf("close", "pause", "play", "stop", "other")
        }
    }

    private companion object {
        const val MODEL_ASSET = "model.tflite"
        const val SAMPLE_IMAGE_ASSET = "button.jpg"
        /** Side length in pixels of the square model input (IMAGE_SIZE in training). */
        const val INPUT_SIZE = 24
        const val CHANNELS = 3
        const val BYTES_PER_FLOAT = 4
    }
}
<?xml version="1.0" encoding="utf-8"?>
<!-- Layout for MainActivity: an image view showing the scaled classifier input
     above a text view that MainActivity fills with the prediction scores. -->
<android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<!-- Displays the 24x24 input bitmap enlarged to 100dp; placed above the text view. -->
<ImageView
android:id="@+id/image"
android:layout_width="100dp"
android:layout_height="100dp"
app:layout_constraintBottom_toTopOf="@id/text"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent" />
<!-- Centered in the parent. "Hello World!" is placeholder text that MainActivity
     replaces with the prediction output at runtime. -->
<TextView
android:id="@+id/text"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Hello World!"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent" />
</android.support.constraint.ConstraintLayout>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment