Skip to content

Instantly share code, notes, and snippets.

@makorowy
Created February 28, 2018 06:57
Show Gist options
  • Save makorowy/905a764523659085efdcca47235a967f to your computer and use it in GitHub Desktop.
Save makorowy/905a764523659085efdcca47235a967f to your computer and use it in GitHub Desktop.
TensorFlow Hot or Not example. Image classifier.
/**
 * Passed as the stats flag to `TensorFlowInferenceInterface.run` in [ImageClassifier];
 * presumably enables per-run statistics collection/logging. Kept off here.
 */
private const val ENABLE_LOG_STATS: Boolean = false
/**
 * [Classifier] that runs a TensorFlow graph through [TensorFlowInferenceInterface] and reports
 * the single highest-confidence label.
 *
 * Every buffer is supplied by the caller and reused on each [recognizeImage] call, so one instance
 * must not be used for concurrent classifications (the calls would overwrite each other's data).
 *
 * @param inputName name of the graph input node that receives [imageNormalizedPixels].
 * @param outputName name of the graph output node whose values are fetched into [results].
 * @param imageSize size used for both spatial dimensions of the input tensor.
 * @param labels class labels; `labels[i]` is paired with the score in `results[i]`.
 * @param imageBitmapPixels scratch buffer intended for raw bitmap pixels (not read by the code shown here).
 * @param imageNormalizedPixels buffer fed to the graph as the input tensor.
 * @param results buffer that receives the graph's output scores.
 * @param tensorFlowInference inference session used to feed, run and fetch.
 */
class ImageClassifier(
    private val inputName: String,
    private val outputName: String,
    private val imageSize: Long,
    private val labels: List<String>,
    private val imageBitmapPixels: IntArray,
    private val imageNormalizedPixels: FloatArray,
    private val results: FloatArray,
    private val tensorFlowInference: TensorFlowInferenceInterface
) : Classifier {

    /**
     * Classifies [bitmap] and returns the [Result] with the highest confidence.
     *
     * @throws IllegalStateException if the model produced no scores to rank (empty output buffer).
     */
    override fun recognizeImage(bitmap: Bitmap): Result {
        preprocessImageToNormalizedFloats(bitmap)
        classifyImageToOutputs()
        // poll() is a Java platform call that yields null for an empty queue; fail with a clear message
        // instead of an opaque NullPointerException at the return boundary.
        return getResults().poll()
            ?: error("No classification results: the model output buffer is empty")
    }

    /**
     * Intended to convert the 0-255 pixel values of [bitmap] into normalized floats in
     * [imageNormalizedPixels].
     */
    private fun preprocessImageToNormalizedFloats(bitmap: Bitmap) {
        // Preprocess the image data from 0-255 int to normalized float based
        // on the provided parameters.
        // TODO(review): the body is empty in this version, so imageNormalizedPixels is never filled
        // from `bitmap` and the graph receives whatever the buffer already holds. Confirm against the
        // full implementation (the normalization mean/std are not visible here).
    }

    /** Feeds the input buffer to the graph, runs it, and copies the output node into [results]. */
    private fun classifyImageToOutputs() {
        // Feed the normalized pixels as a 1 x imageSize x imageSize x COLOR_CHANNELS input tensor.
        tensorFlowInference.feed(
            inputName, imageNormalizedPixels,
            1L, imageSize, imageSize, COLOR_CHANNELS.toLong()
        )
        // Run the classification.
        tensorFlowInference.run(arrayOf(outputName), ENABLE_LOG_STATS)
        // Get the results from the output.
        tensorFlowInference.fetch(outputName, results)
    }

    /** Pairs each output score with its label, returning them in a queue ordered by descending confidence. */
    private fun getResults(): PriorityQueue<Result> =
        results.indices.mapTo(createOutputQueue()) { index -> Result(labels[index], results[index]) }

    /** Creates an empty queue whose head is always the highest-confidence [Result]. */
    private fun createOutputQueue(): PriorityQueue<Result> {
        // java.util.PriorityQueue rejects an initial capacity below 1, so guard against an empty label list.
        val initialCapacity = labels.size.coerceAtLeast(1)
        // Reversed comparison (second vs first) puts the largest confidence at the head of the queue.
        val byConfidenceDescending = Comparator<Result> { (_, firstConfidence), (_, secondConfidence) ->
            secondConfidence.compareTo(firstConfidence)
        }
        // Java constructors do not accept Kotlin named arguments, so both arguments are positional.
        return PriorityQueue(initialCapacity, byConfidenceDescending)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment