`lib/model.dart`:
- Defines:
  - `getPrediction`: calls the `predict` method
  - `getImagePrediction`: calls the `predictImage` method and then compares the maximum score against the label list (see the sketch after this list)
  - `getImagePredictionList`: calls the `predictImage` method, returns all scores
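
To make the `getImagePrediction` step concrete: the label comparison is just an argmax over the score list. A minimal sketch in Java (the plugin does this in Dart; the class and method names here are illustrative, not the plugin's):

```java
import java.util.List;

public class LabelMatcher {
    // Mirrors what getImagePrediction does with predictImage's output:
    // find the index of the maximum score, return the label at that index.
    static String bestLabel(float[] scores, List<String> labels) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) best = i;
        }
        return labels.get(best);
    }
}
```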
`android/src/main/java/io/funn/pytroch_mobile/PyTorchMobilePlugin.java`:
- Intercepts various method calls, essentially as a switch statement (sketched after this list)
- Methods are:
  - `loadModel`: just loads the model file
  - `predict`:
    - Loads data, other args
    - Input tensor: `getInputTensor(dtype, data, shape)`
      - Calls a `Convert` method based on the `dtype`, which is an enumeration of data types (e.g. `FLOAT32`)
    - Output tensor: `module.forward(IValue.from(inputTensor)).toTensor()`
  - `predictImage`:
    - Loads image data, other arguments
    - Image input tensor: `TensorImageUtils.bitmapToFloat32Tensor`
    - Image output tensor: `imageModule.forward(IValue.from(imageInputTensor)).toTensor()`
    - Extracts scores from the image output tensor
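
Condensing that dispatch into a sketch (argument handling, dtype conversion, and score extraction are elided; the private helpers are placeholders of mine, not the plugin's real names):

```java
import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class DispatchSketch implements MethodChannel.MethodCallHandler {
    @Override
    public void onMethodCall(MethodCall call, MethodChannel.Result result) {
        switch (call.method) {
            case "loadModel":
                // Loads the model file, returns its index (see the registry sketch below).
                break;
            case "predict": {
                Module module = lookup(call);
                Tensor inputTensor = inputTensor(call); // getInputTensor(dtype, data, shape)
                Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
                result.success(outputTensor.getDataAsFloatArray());
                break;
            }
            case "predictImage": {
                Module imageModule = lookup(call);
                Tensor imageInputTensor = imageTensor(call); // TensorImageUtils.bitmapToFloat32Tensor(...)
                Tensor imageOutputTensor = imageModule.forward(IValue.from(imageInputTensor)).toTensor();
                result.success(imageOutputTensor.getDataAsFloatArray());
                break;
            }
            default:
                result.notImplemented();
        }
    }

    // Placeholders for the argument handling described above.
    private Module lookup(MethodCall call) { throw new UnsupportedOperationException(); }
    private Tensor inputTensor(MethodCall call) { throw new UnsupportedOperationException(); }
    private Tensor imageTensor(MethodCall call) { throw new UnsupportedOperationException(); }
}
```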
So it seems like it's `IValue.from` and either `module.forward` or `imageModule.forward` doing the work here. Is it the same on iOS?
`ios/Classes/PyTorchMobilePlugin.mm`
- Intercepts the method calls described in `model.dart`
- Like the Android code above, it then handles them as a switch statement
- case 1 (`predict`):
  - Gets arguments and a `TorchModule module`
  - Clones the data for input???
  - Output: calls the `predict` method of the instance `module`
- case 2 (`predictImage`):
  - Gets args and a `TorchModule imageModule`
  - Input: calls `UIImageExtension`'s `resize` and `normalize` on the image
  - Output: calls `imageModule`'s `predictImage`
So the calls to `module` and `imageModule` refer to `TorchModule.mm`...
`ios/Classes/TorchBridge/TorchModule.mm`
- `predictImage` and `predict` are defined here
- `predict`:
  - Lots of shuffling stuff between formats here, but I reckon L66 (`outputTensor = _module.forward({tensor}).toTensor()`) is the important one
- `predictImage`:
  - Same here, L31 (`outputTensor = _module.forward({tensor}).toTensor()`)

This looks like the Android code! But I expected different module references when forwarding. Maybe `module` and `imageModule` are in fact both just identical instantiations of `TorchModule`?
On Android, `module` and `imageModule` are both loaded from `modules.get(index)`, and `modules` is an array populated with `modules.add(Module.load(absPath))`, which runs when `loadModel` is called. `loadModel` is in turn mirrored in `pytorch_mobile.dart`: when it's called in Flutter, it invokes `loadModel` in the bridge, getting back an `index`. That `index` is then used to return `Model(index)`, so we go over to `model.dart`: the index is stored internally in the instantiation and passed to `predict` and `predictImage`.
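
That round trip is small enough to sketch in one place (condensed; the class name is mine, not the plugin's):

```java
import java.util.ArrayList;
import java.util.List;
import org.pytorch.Module;

public class ModuleRegistry {
    private static final List<Module> modules = new ArrayList<>();

    // loadModel: load a TorchScript file, hand the caller its index.
    // On the Dart side this index comes back wrapped as Model(index).
    public static int loadModel(String absPath) {
        modules.add(Module.load(absPath));
        return modules.size() - 1;
    }

    // Both predict and predictImage fetch the same kind of Module by index.
    public static Module byIndex(int index) {
        return modules.get(index);
    }
}
```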
Each `Model` has an `index`, and that index is also the index of the loaded model file (which is opened with `Module.load()`, so it's also a module) in the list of `modules` on Android. Similarly, on iOS, modules are populated with `addObject` (`PyTorchMobilePlugin.mm:L27`). But the latter holds `TorchModule` instances (from `TorchModule.mm`), each of which has a member `torch::jit::script::Module _module` that is loaded when it calls `_module = torch::jit::load(filePath.UTF8String)` in the method `initWithFileAtPath`.
So where do `Module.load` and `modules addObject` come from?
- Android: `import org.pytorch.Module`
- iOS: `torch::jit::load()` in `LibTorch/LibTorch.h`, which returns a `torch::jit::script::Module`
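
So on Android, the entire chain these notes keep circling fits in a few lines. A minimal sketch using `org.pytorch` (path, input, and shape are placeholders):

```java
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class MinimalForward {
    public static float[] run(String modelPath, float[] input, long[] shape) {
        Module module = Module.load(modelPath);             // org.pytorch.Module
        Tensor inputTensor = Tensor.fromBlob(input, shape); // pack the input
        Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
        return outputTensor.getDataAsFloatArray();          // unpack the output
    }
}
```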
Before we crack open PyTorch itself, let's have a look at the PyTorch Android app and iOS app to see how they handle this, since they have Object Detection demos and theoretically use the same interface.
Let's start in the Inference folder.
`ObjectDetector.swift` gets a "bundle" (i.e. internal app file) path for the YOLOv5s object detector model, then loads the module using `InferenceModule(fileAtPath: filePath)`. Where does that come from? Well, `ObjectDetection-Bridging-Header.h` loads `InferenceModule.h`. That declares an interface that includes a `detectImage` method.
[InferenceModule.mm](https://github.com/pytorch/ios-demo-app/blob/master/ObjectDetection/ObjectDetection/Inference/InferenceModule.mm) imports `Libtorch-Lite/Libtorch-Lite.h`, which appears to be where it gets similar torch library structures like its internal member, `torch::jit::mobile::Module _impl`. It has its own version of `initWithFileAtPath`, but it calls `torch::jit::_load_for_mobile` instead of `torch::jit::load`.
Below this, there's `detectImage`. Interestingly, it tracks the inference time (presumably to report back for demo use). But it also uses `outputTuple = _impl.forward({ tensor }).toTuple()` at the end of the day.
What goes into the `forward` method of the module (which seems to be determined by the file)? A `tensor`, which is:

```cpp
torch::from_blob(imageBuffer, { 1, 3, input_height, input_width }, at::kFloat);
```
Compare that to the Flutter image classifier:

```cpp
at::Tensor tensor = torch::from_blob(imageBuffer, {1, 3, height, width}, at::kFloat);
```
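
Both pack the same `{1, 3, height, width}` float layout. For reference, a sketch of the Java-side analogue (assuming the buffer holds CHW-ordered RGB floats; names are illustrative):

```java
import org.pytorch.Tensor;

public class FromBlobSketch {
    // Same {1, 3, H, W} float packing as the C++ torch::from_blob calls above.
    static Tensor pack(float[] imageBuffer, int height, int width) {
        return Tensor.fromBlob(imageBuffer, new long[] {1, 3, height, width});
    }
}
```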
Following this it's mostly cleanup, unpacking the output tensor and then unfolding the detected objects (not sure if those are in the form of coordinates and label indices, subsets of the image buffer, or something else).
I'm starting to get the feeling that if you know how to pack the input tensor and unpack the output tensor, `customModel` can probably actually do all of this without needing to modify the plugin at all. Scratch that, see below: we can probably just use `getImagePredictionList`!
`android-demo-app/ObjectDetection/app/src/main/java/org/pytorch/demo/objectdetection/ObjectDetectionActivity.java`
The `analyzeImage` method, like the Android Flutter app above, has an `outputTuple = mModule.forward(IValue.from(inputTensor)).toTuple()` line (L102). Let's have a look at the input tensor going into it on the previous line:

```java
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, PrePostProcessor.NO_MEAN_RGB, PrePostProcessor.NO_STD_RGB);
```
(Compare to the `predictImage` method of the Flutter image classifier):

```java
final Tensor imageInputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, mean, std);
```
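
The only visible difference is the normalization constants. If I'm reading the demo's `PrePostProcessor` right (worth double-checking), `NO_MEAN_RGB`/`NO_STD_RGB` are identity stats, so both apps are doing the same packing:

```java
import android.graphics.Bitmap;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

public class PackingSketch {
    // Identity normalization, as the demo's NO_MEAN_RGB / NO_STD_RGB appear to be.
    static final float[] NO_MEAN_RGB = {0.0f, 0.0f, 0.0f};
    static final float[] NO_STD_RGB = {1.0f, 1.0f, 1.0f};

    // Detector-style packing: no normalization applied.
    static Tensor packForDetector(Bitmap resizedBitmap) {
        return TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, NO_MEAN_RGB, NO_STD_RGB);
    }

    // Classifier-style packing: caller-supplied mean/std (e.g. ImageNet stats).
    static Tensor packForClassifier(Bitmap bitmap, float[] mean, float[] std) {
        return TensorImageUtils.bitmapToFloat32Tensor(bitmap, mean, std);
    }
}
```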
It seems exactly the same on Android! Can I even skip `customModel` and use `imageModel`'s `getImagePredictionList` in the Flutter app instead of `getImagePrediction`? Or just slightly modify it to account for a different output structure?
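
To pin down that "different output structure": the classifier path ends in `.toTensor()`, while the detection demo ends in `.toTuple()`. A sketch of the two unpacking styles (the tuple indexing follows the YOLOv5 demo; other models may order outputs differently):

```java
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class UnpackingSketch {
    // Classifier-style output: forward() yields a single score tensor.
    static float[] classifierScores(Module module, Tensor input) {
        return module.forward(IValue.from(input)).toTensor().getDataAsFloatArray();
    }

    // Detector-style output: forward() yields a tuple; in the YOLOv5 demo
    // the first element holds the predictions.
    static float[] detectorOutputs(Module module, Tensor input) {
        IValue[] outputTuple = module.forward(IValue.from(input)).toTuple();
        return outputTuple[0].toTensor().getDataAsFloatArray();
    }
}
```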
It looks like the methods of packing the input tensor are slightly different on iOS and Android. Let's now see if we can find these functions in PyTorch Mobile.
- In iOS: `Libtorch-Lite/Libtorch-Lite.h` or `LibTorch/LibTorch.h`
- In Android: `org.pytorch.torchvision.TensorImageUtils`