Skip to content

Instantly share code, notes, and snippets.

@shubham0204
Last active June 13, 2019 05:44
Show Gist options
  • Save shubham0204/66b44cdad617a18ca227867f8f6c54e6 to your computer and use it in GitHub Desktop.
Save shubham0204/66b44cdad617a18ca227867f8f6c54e6 to your computer and use it in GitHub Desktop.
/**
 * Memory-maps a TensorFlow Lite model stored in the app's assets.
 *
 * The returned buffer is a read-only mapping covering the asset's bytes: it starts at the
 * descriptor's `startOffset` and spans `declaredLength` bytes. It can be handed directly to an
 * `Interpreter`.
 *
 * Both the asset descriptor and the input stream are closed before returning. This is safe
 * because, per the [FileChannel.map] contract, a mapping does not depend on the channel that
 * created it once established.
 *
 * @param modelAssetPath path of the model file inside the assets directory; defaults to `"model.tflite"`.
 * @return a read-only [MappedByteBuffer] over the model bytes.
 * @throws IOException if the asset cannot be opened or mapped.
 */
@Throws(IOException::class)
private fun loadModelFile(modelAssetPath: String = "model.tflite"): MappedByteBuffer =
    assets.openFd(modelAssetPath).use { descriptor ->
        FileInputStream(descriptor.fileDescriptor).use { stream ->
            // Map only this asset's region of the underlying file.
            stream.channel.map(
                FileChannel.MapMode.READ_ONLY,
                descriptor.startOffset,
                descriptor.declaredLength
            )
        }
    }
/**
 * Runs the TensorFlow Lite model once on [sequence] and returns the model's output row.
 *
 * Each call loads the model, creates an `Interpreter`, performs one inference, and closes the
 * interpreter so its native resources are released even if inference throws.
 *
 * @param sequence input rows; every value is converted from Double to Float. The whole array is
 *   wrapped as a batch of one, so the tensor fed to the model is `[1][sequence.size][row length]`.
 * @return the first (and only) output row, a 2-element FloatArray.
 */
fun classifySequence ( sequence : Array<DoubleArray> ): FloatArray {
    val interpreter = Interpreter( loadModelFile() )
    try {
        // Convert each row without boxing through List<Float>.
        val rows : Array<FloatArray> = Array( sequence.size ) { i ->
            val row = sequence[ i ]
            FloatArray( row.size ) { j -> row[ j ].toFloat() }
        }
        // Batch dimension of one.
        val inputs : Array<Array<FloatArray>> = arrayOf( rows )
        // Output buffer shaped [1][2]; the interpreter writes the results into it.
        val outputs : Array<FloatArray> = arrayOf( FloatArray( 2 ) )
        interpreter.run( inputs , outputs )
        return outputs[ 0 ]
    } finally {
        // The original never released the interpreter, leaking native memory on every call.
        interpreter.close()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment