Last active
July 24, 2023 21:48
-
-
Save jhabr/aced0ee37d6a585616e3edce47706235 to your computer and use it in GitHub Desktop.
Hugging Face DistilBERT conversion to TFLite using TF 2.x, and usage on Android
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import TFDistilBertForSequenceClassification, DistilBertConfig | |
import tensorflow as tf | |
# Conversion script: wraps a pretrained DistilBERT sequence classifier in a
# fixed-shape Keras model and exports it as a TFLite flatbuffer.

# Shared dimensions, defined once so the config, the Keras inputs, the
# classification head and the output filename cannot drift apart.
SEQUENCE_LENGTH = 100
NO_CLASSES = 10

# load pretrained model
config = DistilBertConfig.from_pretrained('distilbert-base-cased', num_labels=NO_CLASSES)
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-cased', from_pt=True, config=config)

# create input and output
# Both inputs are fixed to a batch of one sequence of SEQUENCE_LENGTH int32 values.
input_ids = tf.keras.layers.Input(shape=(SEQUENCE_LENGTH,), dtype=tf.int32, name="input_ids", batch_size=1)
attention_mask = tf.keras.layers.Input(shape=(SEQUENCE_LENGTH,), dtype=tf.int32, name="attention_mask", batch_size=1)
inputs = [input_ids, attention_mask]

# Pass the mask by keyword so it cannot be bound to a different positional parameter.
outputs = model(input_ids, attention_mask=attention_mask)[0]
# NOTE(review): this Dense layer is newly created here and is not part of the
# pretrained checkpoint; presumably the wrapped model is fine-tuned before the
# exported file is used for real predictions -- confirm.
outputs = tf.keras.layers.Dense(NO_CLASSES, activation='softmax')(outputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

# export model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True
converter.allow_custom_ops = True
# Allow falling back to TensorFlow ops for anything the TFLite builtins do not cover.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]

# For conversion with FP16 quantization
# converter.target_spec.supported_types = [tf.float16]
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]

tflite_model = converter.convert()

# Write inside a context manager so the file handle is flushed and closed
# deterministically. For SEQUENCE_LENGTH = 100 this yields "distilbert_L100.tflite".
with open(f"distilbert_L{SEQUENCE_LENGTH}.tflite", 'wb') as tflite_file:
    tflite_file.write(tflite_model)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import android.content.Context | |
import org.tensorflow.lite.Interpreter | |
import timber.log.Timber | |
import java.io.FileInputStream | |
import java.nio.MappedByteBuffer | |
import java.nio.channels.FileChannel | |
import kotlin.system.measureTimeMillis | |
/* | |
make sure to include the following in the app/build.gradle dependencies: | |
implementation 'org.tensorflow:tensorflow-lite:2.5.0' | |
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.5.0' | |
*/ | |
/**
 * Runs a TensorFlow Lite text-classification model that is bundled in the app's assets.
 *
 * Constructing an instance memory-maps [MODEL_PATH], reads the lines of [LABELS_PATH] and creates
 * the [Interpreter]. All of this happens synchronously in the constructor.
 *
 * Call [close] once the classifier is no longer needed to release the interpreter.
 *
 * @param context used to reach the asset manager that holds the model and label files.
 */
class TextClassifier(private val context: Context) : AutoCloseable {

    // Held as a property so the mapped buffer stays referenced for as long as the interpreter exists.
    private val model: MappedByteBuffer = loadModel()

    // One entry per line of the labels asset; its size determines the width of the output buffer.
    private val labels: List<String> = loadLabels()

    private val interpreter: Interpreter = Interpreter(
        model,
        Interpreter.Options().apply {
            setNumThreads(5)
            setUseNNAPI(true)
        }
    )

    /**
     * Memory-maps the model asset read-only.
     *
     * The asset descriptor and the stream are closed before returning; a mapping stays valid
     * after the channel that created it has been closed.
     */
    private fun loadModel(): MappedByteBuffer =
        context.assets.openFd(MODEL_PATH).use { descriptor ->
            FileInputStream(descriptor.fileDescriptor).use { stream ->
                stream.channel.map(
                    FileChannel.MapMode.READ_ONLY,
                    descriptor.startOffset,
                    descriptor.declaredLength
                )
            }
        }

    /** Reads the labels asset and returns its lines in file order. */
    private fun loadLabels(): List<String> =
        context.assets.open(LABELS_PATH).bufferedReader().useLines { it.toList() }

    /**
     * Feeds [input] to the interpreter and logs the values written to output 0.
     *
     * Input example:
     *
     *     inputIds = intArrayOf(102, 4845, 103)
     *     attentionMask = intArrayOf(1, 1, 1)
     *     input = arrayOf(inputIds, attentionMask)
     *
     * NOTE(review): the example arrays have three elements; presumably they must be padded to the
     * sequence length the model was exported with, and ordered the way the converted model
     * orders its inputs -- confirm against the .tflite file.
     *
     * @param input one [IntArray] per model input, passed to the interpreter in the given order.
     */
    fun runInference(input: Array<IntArray>) {
        // Output 0 is received as a single row with one float per label.
        val scores = Array(1) { FloatArray(labels.size) }
        val output: Map<Int, Any> = mapOf(0 to scores)
        interpreter.runForMultipleInputsOutputs(input, output)

        val probabilities = scores[0]
        // contentToString() prints the values; interpolating the array itself would only print
        // its type and identity hash.
        Timber.d("Probabilities Results: ${probabilities.contentToString()}")
    }

    /** Releases the resources held by the interpreter. Do not call [runInference] afterwards. */
    override fun close() {
        interpreter.close()
    }

    companion object {
        private const val MODEL_PATH = "distilbert_L100.tflite"
        private const val LABELS_PATH = "labels.txt"
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment