@jimjam-slam
Last active November 19, 2022 08:47
Flutter/PyTorch notes

flutter_pytorch_mobile

lib/model.dart:

  • Defines:
    • getPrediction: calls the predict method
    • getImagePrediction: calls the predictImage method, finds the maximum score, and looks up the corresponding entry in the label list
    • getImagePredictionList: calls the predictImage method, returns all scores
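The label lookup in getImagePrediction amounts to an argmax over the scores. Here is a minimal Java sketch of that step. It assumes the scores come back as a flat array in the same order as the label list. TopLabel and its methods are hypothetical names and are not part of the plugin.

```java
import java.util.List;

/** Hypothetical helper mirroring the argmax-then-label step of getImagePrediction. */
final class TopLabel {
    /** Index of the largest score; the first one wins on ties. */
    static int argmax(double[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("no scores");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Label at the argmax position; assumes one label per score, in the same order. */
    static String topLabel(double[] scores, List<String> labels) {
        if (labels.size() != scores.length) {
            throw new IllegalArgumentException("label count does not match score count");
        }
        return labels.get(argmax(scores));
    }

    public static void main(String[] args) {
        double[] scores = {0.1, 2.5, -0.3};
        System.out.println(topLabel(scores, List.of("cat", "dog", "bird")));
    }
}
```

getImagePredictionList is the same call without the final lookup: it hands back the raw score array.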

android/src/main/java/io/funn/pytroch_mobile/PyTorchMobilePlugin.java:

  • Intercepts various method calls as essentially a switch statement:
  • Methods are:
    • loadModel: just loads the model file
    • predict:
      • Loads data, other args
      • Input tensor: getInputTensor(dtype, data, shape)
        • Calls a Convert method based on the dtype, which is an enumeration of data types (e.g. FLOAT32)
      • Output tensor: module.forward(IValue.from(inputTensor)).toTensor()
    • predictImage:
      • Loads image data, other arguments
      • image input tensor: TensorImageUtils.bitmapToFloat32Tensor
      • image output tensor: imageModule.forward(IValue.from(imageInputTensor)).toTensor()
      • extract scores from image output tensor
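The dtype switch in getInputTensor probably boils down to two steps. First it checks that the shape matches the amount of data. Then it copies the numbers into a primitive array of the requested type. The sketch below is hedged: DType, InputConversion and the List<Double> input format are my assumptions, not the plugin's actual code. The real plugin would also wrap the array in an org.pytorch Tensor (e.g. via Tensor.fromBlob) instead of returning it.

```java
import java.util.List;

/** Hypothetical stand-in for the plugin's dtype enumeration. */
enum DType { FLOAT32, INT32, INT64 }

final class InputConversion {
    /** Number of elements implied by a shape, e.g. {1, 3, 224, 224} gives 150528. */
    static long numel(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    /** Copies channel data into a primitive array of the requested dtype. */
    static Object convert(DType dtype, List<Double> data, long[] shape) {
        if (numel(shape) != data.size()) {
            throw new IllegalArgumentException("shape does not match data length");
        }
        switch (dtype) {
            case FLOAT32: {
                float[] out = new float[data.size()];
                for (int i = 0; i < out.length; i++) out[i] = data.get(i).floatValue();
                return out;
            }
            case INT32: {
                int[] out = new int[data.size()];
                for (int i = 0; i < out.length; i++) out[i] = data.get(i).intValue();
                return out;
            }
            case INT64: {
                long[] out = new long[data.size()];
                for (int i = 0; i < out.length; i++) out[i] = data.get(i).longValue();
                return out;
            }
            default:
                throw new IllegalArgumentException("unsupported dtype: " + dtype);
        }
    }
}
```

Checking the shape up front gives a readable error on the Java side. Without the check, a mismatch would only surface later, when the tensor is built.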

So it seems like it's IValue.from and either module.forward or imageModule doing the work here.

Is it the same on iOS?

ios/Classes/PyTorchMobilePlugin.mm

  • Intercepts the method calls described in model.dart
  • Like the Android code above, handles them in a switch statement
  • case 1 (@predict):
    • Gets arguments and TorchModule module
    • clone the data for input???
    • output: calls the predict method from the instance module
  • case 2 (@predictImage):
    • Gets args and TorchModule imageModule
    • input: calls UIImageExtension's resize and normalize on the image
    • output: calls imageModule's predictImage

So the calls to module or imageModule refer to TorchModule.mm...

ios/Classes/TorchBridge/TorchModule.mm

  • predictImage and predict defined here
  • predict:
    • Lots of shuffling stuff between formats here, but I reckon L66 (outputTensor = _module.forward({tensor}).toTensor()) is the important one
  • predictImage:
    • Same here: L31 (outputTensor = _module.forward({tensor}).toTensor())

This looks like the Android code! But I expected different module references when forwarding. Maybe module and imageModule are in fact both just identical instantiations of TorchModule?

On Android, module and imageModule are both loaded from modules.get(index), and modules is an array populated with modules.add(Module.load(absPath)), which runs when loadModel is called.

loadModel is in turn mirrored in pytorch_mobile.dart: when it's called in Flutter, it invokes loadModel in the bridge and gets back an index. That index is then used to return Model(index), so we go over to model.dart. There, the index is stored internally in the instance and passed to predict and predictImage.
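The index-passing arrangement can be sketched in plain Java. Forwardable is a hypothetical stand-in for org.pytorch.Module. The bridge field stands in for the method channel: the real Dart Model sends calls through the channel and does not hold a direct reference to the plugin.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for org.pytorch.Module: anything with a forward pass. */
interface Forwardable {
    float[] forward(float[] input);
}

/** Plugin side: the index handed back to Dart is just a position in this list. */
final class ModuleRegistry {
    private final List<Forwardable> modules = new ArrayList<>();

    /** Mirrors loadModel: store the loaded module and return its index. */
    int loadModel(Forwardable module) {
        modules.add(module);
        return modules.size() - 1;
    }

    /** Mirrors predict: look the module up by index and run it. */
    float[] predict(int index, float[] input) {
        return modules.get(index).forward(input);
    }
}

/** Dart side, sketched in Java: a Model only remembers its index. */
final class Model {
    private final int index;
    private final ModuleRegistry bridge;

    Model(int index, ModuleRegistry bridge) {
        this.index = index;
        this.bridge = bridge;
    }

    float[] predict(float[] input) {
        return bridge.predict(index, input);
    }
}
```

This fits the observation that module and imageModule look alike. On the plugin side, both are simply whatever sits at that index in the list.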

tl;dr:

Each Model has an index, and that index is also the index of the loaded model file in the list of modules on Android (the file is opened with Module.load(), so it's also a module). Similarly, on iOS, modules is populated with addObject (PyTorchMobilePlugin.mm:L27). But there the objects being added are TorchModule instances (TorchModule.mm). Each of these has a member torch::jit::script::Module _module, which is loaded by _module = torch::jit::load(filePath.UTF8String) in the initWithFileAtPath method.

So where do Module.load (Android) and the module loading behind addObject (iOS) come from?

  • Android: import org.pytorch.Module
  • iOS: torch::jit::load() in LibTorch/LibTorch.h, which returns a torch::jit::script::Module

Before we crack open PyTorch itself, let's have a look at the PyTorch Android app and iOS app to see how they handle this, since they have Object Detection demos and theoretically use the same interface.

Let's start in the Inference folder.

ObjectDetector.swift gets a "bundle" (i.e. internal app file) path for the YOLOv5s object detector model, then loads the module using InferenceModule(fileAtPath: filePath). Where does that come from? Well, ObjectDetection-Bridging-Header.h imports InferenceModule.h, which declares an interface that includes a detectImage method.

[InferenceModule.mm](https://github.com/pytorch/ios-demo-app/blob/master/ObjectDetection/ObjectDetection/Inference/InferenceModule.mm) imports Libtorch-Lite/Libtorch-Lite.h. That appears to be where it gets similar torch library structures, like its internal member torch::jit::mobile::Module _impl. It has its own version of initWithFileAtPath, but it calls torch::jit::_load_for_mobile instead of torch::jit::load.

Below this, there's detectImage. Interestingly, it tracks the inference time (presumably to report back for demo use). But it also uses outputTuple = _impl.forward({ tensor }).toTuple() at the end of the day.
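Timing the forward pass only needs a clock read on either side of the call. The helper below is hypothetical and only shows the pattern. It uses System.nanoTime because that clock is monotonic. I haven't checked which clock the demo itself uses.

```java
import java.util.function.Supplier;

/** Hypothetical helper: runs a call and records how long it took. */
final class Timed<T> {
    final T value;
    final double millis;

    private Timed(T value, double millis) {
        this.value = value;
        this.millis = millis;
    }

    /** Wraps e.g. () -> module.forward(...) and returns its result plus elapsed time. */
    static <R> Timed<R> run(Supplier<R> call) {
        long start = System.nanoTime();
        R value = call.get();
        double millis = (System.nanoTime() - start) / 1_000_000.0;
        return new Timed<>(value, millis);
    }
}
```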

What goes into the forward method of the module (which seems to be determined by the file)? A tensor, which is:

torch::from_blob(imageBuffer, { 1, 3, input_height, input_width }, at::kFloat);

Compare this to the Flutter image classifier:

at::Tensor tensor = torch::from_blob(imageBuffer, {1, 3, height, width}, at::kFloat);
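Both versions pass from_blob a flat float buffer and the shape {1, 3, height, width}. With default contiguous strides, the buffer therefore has to be in row-major NCHW order: the whole red plane, then green, then blue. from_blob also wraps the buffer without copying it, so the buffer must outlive the tensor unless the tensor is cloned. Below is a small Java sketch of the index arithmetic. The helper name is hypothetical and there is no PyTorch dependency.

```java
/** Hypothetical helper: index arithmetic for a contiguous NCHW float buffer. */
final class NchwLayout {
    /** Flat offset of element (n, c, y, x) in a buffer of shape {N, C, H, W}. */
    static int offset(int n, int c, int y, int x, int channels, int height, int width) {
        return ((n * channels + c) * height + y) * width + x;
    }

    /** Value of channel c at pixel (x, y) in a {1, C, H, W} buffer. */
    static float at(float[] buffer, int c, int y, int x, int channels, int height, int width) {
        return buffer[offset(0, c, y, x, channels, height, width)];
    }
}
```

Take a 2 x 2 image as an example. The green value of the top-left pixel sits at offset 4, not offset 1 as it would in an interleaved RGB buffer.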

Following this it's mostly cleanup: unpacking the output tensor and then unfolding the detected objects. I'm not sure whether those are in the form of coordinates and label indices, subsets of the image buffer or something else.
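For the YOLOv5 demo, the Android app's PrePostProcessor appears to take the first element of the output tuple and treat it as a flat array of rows. Each row appears to be laid out as [centre x, centre y, width, height, objectness, one score per class]. Assuming that layout, decoding comes down to the loop below. Detection, YoloDecoder and the threshold are illustrative names. The demo also rescales boxes to the original image and runs non-maximum suppression, and this sketch skips both.

```java
import java.util.ArrayList;
import java.util.List;

/** One decoded box, in the coordinate space of the model input. */
final class Detection {
    final float left, top, right, bottom;
    final float score;
    final int classIndex;

    Detection(float left, float top, float right, float bottom, float score, int classIndex) {
        this.left = left;
        this.top = top;
        this.right = right;
        this.bottom = bottom;
        this.score = score;
        this.classIndex = classIndex;
    }
}

final class YoloDecoder {
    /**
     * Assumes outputs holds rows of (5 + numClasses) floats, row-major:
     * [cx, cy, w, h, objectness, classScore0, classScore1, ...].
     * Rows whose objectness is not above the threshold are dropped.
     */
    static List<Detection> decode(float[] outputs, int numClasses, float threshold) {
        int cols = 5 + numClasses;
        if (numClasses < 1 || outputs.length % cols != 0) {
            throw new IllegalArgumentException("output length is not a whole number of rows");
        }
        List<Detection> detections = new ArrayList<>();
        for (int base = 0; base < outputs.length; base += cols) {
            float objectness = outputs[base + 4];
            if (objectness <= threshold) {
                continue;
            }
            // Pick the class with the highest score in this row.
            int bestClass = 0;
            for (int k = 1; k < numClasses; k++) {
                if (outputs[base + 5 + k] > outputs[base + 5 + bestClass]) {
                    bestClass = k;
                }
            }
            // Convert centre/size to corner coordinates.
            float cx = outputs[base], cy = outputs[base + 1];
            float w = outputs[base + 2], h = outputs[base + 3];
            detections.add(new Detection(cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2, objectness, bestClass));
        }
        return detections;
    }
}
```

Nothing here depends on PyTorch. If the raw scores come back as one flat list, the same loop could presumably be ported to Dart.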

I'm starting to get the feeling that if you know how to pack the input tensor and unpack the output tensor, customModel can probably actually do all of this without needing to modify the plugin at all. Scratch that, see below! We can probably just use getImagePredictionList!

android-demo-app/ObjectDetection/app/src/main/java/org/pytorch/demo/objectdetection/ObjectDetectionActivity.java

The analyzeImage method, like the Android Flutter plugin above, has an outputTuple = mModule.forward(IValue.from(inputTensor)).toTuple() line (L102). Let's have a look at the input tensor going into it, on the previous line:

final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, PrePostProcessor.NO_MEAN_RGB, PrePostProcessor.NO_STD_RGB);

(Compare this to the predictImage method of the Flutter image classifier:)

final Tensor imageInputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                mean, std);
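Both calls take a mean and a standard deviation per RGB channel. As far as I can tell, bitmapToFloat32Tensor does three things to each pixel. It scales each channel to [0, 1], applies (value - mean) / std, and writes the result in planar CHW order. Here is a plain-Java sketch of that transform. It assumes packed ARGB pixels like those from Bitmap.getPixels, and it is my own re-implementation, not the library source.

```java
/** Hypothetical re-implementation of the normalization step, not the library source. */
final class PixelNormalizer {
    /**
     * Converts packed ARGB pixels into a planar CHW float array: all R values,
     * then all G, then all B, each scaled to [0, 1] and normalized per channel
     * as (value - mean[c]) / std[c]. Alpha is ignored.
     */
    static float[] toFloatChw(int[] argbPixels, float[] mean, float[] std) {
        int n = argbPixels.length;
        float[] out = new float[3 * n];
        for (int i = 0; i < n; i++) {
            int p = argbPixels[i];
            float r = ((p >> 16) & 0xff) / 255.0f;
            float g = ((p >> 8) & 0xff) / 255.0f;
            float b = (p & 0xff) / 255.0f;
            out[i] = (r - mean[0]) / std[0];
            out[n + i] = (g - mean[1]) / std[1];
            out[2 * n + i] = (b - mean[2]) / std[2];
        }
        return out;
    }
}
```

If I'm reading the demo right, NO_MEAN_RGB is {0, 0, 0} and NO_STD_RGB is {1, 1, 1}. With those values the transform reduces to plain division by 255. An ImageNet-style classifier would instead pass non-trivial mean and std values.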

It seems exactly the same on Android! Can I even skip customModel and use imageModel's getImagePredictionList in the Flutter app instead of getImagePrediction? Or just slightly modify it to account for a different output structure?

It looks like the methods of packing the input tensor are slightly different on iOS and Android. Let's now see if we can find these functions in PyTorch Mobile.

  • In iOS: Libtorch-Lite/Libtorch-Lite.h or LibTorch/LibTorch.h
  • In Android: org.pytorch.torchvision.TensorImageUtils