
Retrain a MobileNet model and use it in the browser with TensorFlow.js

DRAFT

Combining TensorFlow for Poets and TensorFlow.js.
Retrain a MobileNet V1 or V2 model on your own dataset using the CPU only.
I'm using a MacBook Pro without an Nvidia GPU.

MobileNets can be used for image classification. This guide shows the steps I took to retrain a MobileNet on a custom dataset, and how to convert and use the retrained model in the browser using TensorFlow.js. The total time to set up, retrain the model and use it in the browser can take less than 30 minutes (depending on the size of your dataset).

Example app - HTML/JS and a retrained MobileNet V1/V2 model.


1. Python setup

Set up a virtual environment in Python. This keeps your system clean and your dependencies separated; it's good practice not to install packages with sudo. You can skip this section if you are already familiar with Python and virtualenv.

We will use virtualenv-burrito, which is a script that installs both virtualenv and virtualenvwrapper.

cd
curl -sL https://raw.githubusercontent.com/brainsik/virtualenv-burrito/master/virtualenv-burrito.sh | $SHELL
source ~/.venvburrito/startup.sh

# Create project environment
mkvirtualenv myproject
deactivate

Activate an environment: workon foo
Deactivate current environment: deactivate
Create an environment: mkvirtualenv foo
Remove an environment: rmvirtualenv foo
List available environments: lsvirtualenv


2. Retrain a MobileNet model using a custom dataset

If you get stuck at any point, see the TensorFlow for Poets codelab or this article.

Activate the project's virtualenv, install the tensorflowjs Python package, and clone either the googlecodelabs/tensorflow-for-poets-2 repository (MobileNet V1) or the tensorflow/hub repository (MobileNet V2).

# Activate project environment
# Install the tensorflowjs Python package (includes tensorflow, tensorboard, tensorflowjs_converter)
workon myproject
pip install tensorflowjs

# Clone the TensorFlow for Poets repository (MobileNet V1)
git clone https://github.com/googlecodelabs/tensorflow-for-poets-2 retrain-mobilenet-v1
cd retrain-mobilenet-v1

# Clone the TensorFlow Hub repository (MobileNet V2)
git clone https://github.com/tensorflow/hub retrain-mobilenet-v2
cd retrain-mobilenet-v2

Note: all further commands assume that the present working directory is retrain-mobilenet-<version> and that the project's virtualenv is activated.

Run the retrain.py script with the -h (help) flag to see its options.

# cd /path/to/retrain-mobilenet-v1
python -m scripts.retrain -h

# cd /path/to/retrain-mobilenet-v2
python examples/image_retraining/retrain.py -h

Add a dataset

Create a tf_files directory with the folders below and add your dataset (one folder of images per class) to the tf_files/dataset directory. The classification labels used when running inference are generated from the folder names, as shown in the example layout after the commands below.

# Create folders
mkdir -p tf_files/{bottlenecks,dataset,models,training_summaries}

# Add your dataset
cp -R /path/to/my-dataset/* tf_files/dataset

# Or use the 'flowers' dataset
curl http://download.tensorflow.org/example_images/flower_photos.tgz | tar xz -C .
cp -R flower_photos/* tf_files/dataset
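
With the flowers dataset, for example, tf_files/dataset contains one folder of images per class:

tf_files/dataset
├── daisy
│   ├── 21652746_cc379e0eea_m.jpg
│   └── ...
├── dandelion
├── roses
├── sunflowers
└── tulips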

Start TensorBoard

Start TensorBoard from a new terminal window and visit http://localhost:6006/.

# Open new terminal window
cd /path/to/retrain-mobilenet-<version>
workon myproject
tensorboard --logdir tf_files/training_summaries

Retrain a model using a pre-trained MobileNet V1 model

To retrain a MobileNet V1 model, choose an architecture from this page and run the retrain.py script. We found that mobilenet_0.50_224 provides both decent accuracy and an acceptable file size (~2.3MB after gzip compression). Smaller models such as mobilenet_0.25_128 give lower accuracy but require less bandwidth; larger models give higher accuracy at the cost of a bigger download.

This will take a few minutes (using the flowers dataset), or longer depending on the size of your dataset.

# Set environment variables
IMAGE_SIZE=128
ARCHITECTURE=mobilenet_0.25_$IMAGE_SIZE

# Start training
python -m scripts.retrain \
  --image_dir=tf_files/dataset \
  --model_dir=tf_files/models \
  --architecture=$ARCHITECTURE \
  --output_graph=tf_files/retrained_graph.pb \
  --output_labels=tf_files/retrained_labels.txt \
  --bottleneck_dir=tf_files/bottlenecks \
  --summaries_dir=tf_files/training_summaries/$ARCHITECTURE \
  --how_many_training_steps=400 \
  --learning_rate=0.001

Note: keep an eye on TensorBoard (http://localhost:6006) during training.

For more information and how to adjust hyperparameters, check out the full TensorFlow for Poets codelab.

Retrain a model using a pre-trained MobileNet V2 model

Reference article.
Pick a TFHub module from this page, and copy the link to the pre-trained model of type feature_vector.

# Set environment variables
MODULE=https://tfhub.dev/google/imagenet/mobilenet_v2_035_224/feature_vector/2

# Start training
python examples/image_retraining/retrain.py \
  --image_dir=tf_files/dataset \
  --tfhub_module=$MODULE \
  --output_graph=tf_files/retrained_graph.pb \
  --output_labels=tf_files/retrained_labels.txt \
  --bottleneck_dir=tf_files/bottlenecks \
  --summaries_dir=tf_files/training_summaries \
  --intermediate_output_graphs_dir=tf_files/intermediate_graphs \
  --intermediate_store_frequency=500 \
  --saved_model_dir=tf_files/saved_model \
  --how_many_training_steps=2000 \
  --learning_rate=0.0333

Test the model by classifying an image

Classify an image using the label_image.py script.
(e.g. tf_files/dataset/daisy/21652746_cc379e0eea_m.jpg if you've retrained the model on the flowers dataset).

python label_image.py \
  --graph=tf_files/retrained_graph.pb \
  --input_width=$IMAGE_SIZE \
  --input_height=$IMAGE_SIZE \
  --image=tf_files/dataset/daisy/21652746_cc379e0eea_m.jpg  

# Top result should be 'daisy'

3. Optimize for the web

Quantize graph

Quantize the retrained graph using the quantize_graph.py script. Although you could serve the retrained graph as-is, the quantized graph compresses much better, so serving it saves bandwidth when using gzip compression.

python quantize_graph.py \
  --input=tf_files/retrained_graph.pb \
  --output=tf_files/quantized_graph.pb \
  --output_node_names=final_result \
  --mode=weights_rounded

Optional: compare size of gzipped graphs.

gzip -k tf_files/retrained_graph.pb tf_files/quantized_graph.pb
du -h tf_files/*.gz

# Clean up
rm tf_files/*.gz

Convert to TensorFlow.js model

Convert the quantized retrained graph to a TensorFlow.js compatible model using tensorflowjs_converter, and save it in a new tf_files/web folder.

tensorflowjs_converter \
  --input_format=tf_frozen_model \
  --output_node_names=final_result \
  tf_files/quantized_graph.pb \
  tf_files/web

Add labels

Create a JSON file from retrained_labels.txt using jq, so that we can easily import the dataset labels in JavaScript.

# Install 'jq' once with homebrew (https://brew.sh/)
brew install jq

# Create JSON file from newline-delimited text file
cat tf_files/retrained_labels.txt | jq -Rsc '. / "\n" - [""]' > tf_files/web/labels.json
cat tf_files/web/labels.json

Returns: ["daisy","dandelion","roses","sunflowers","tulips"]

Folder structure after running tensorflowjs_converter and converting the dataset labels to labels.json.

/path/to/tf_files
├── retrained_graph.pb
├── retrained_labels.txt
├── quantized_graph.pb
├── web
│   ├── group1-shard1of1
│   ├── tensorflowjs_model.pb
│   ├── labels.json
│   └── weights_manifest.json
├── bottlenecks
│   └── ...
├── dataset
│   └── ...
├── training_summaries
│   └── ...
└── models|intermediate_graphs|saved_model|...
    └── ...

Optional: check gzipped TensorFlow.js model size.

tar -czf tf_files/web.tar.gz tf_files/web
du -h tf_files/web.tar.gz

# Clean up
rm tf_files/web.tar.gz

4. Classifying images in the browser

Create an app to run predictions in the browser using the retrained model converted by tensorflowjs_converter.

With a few lines of code, we can classify an image using the retrained model. In this example, we use an <img> element as input to get a prediction. Available input types: ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement (see the webcam sketch after the src/index.js listing below).

Prepare app folder structure and install dependencies

Or clone the example repository
# Create app folder structure
mkdir -p myproject-frontend/{public/assets/{model,images},src}
cd myproject-frontend

# Create HTML/JS files
touch public/index.html src/index.js

# Install dependencies
npm init -y
npm install react-scripts @tensorflow/tfjs @tensorflow/tfjs-core @tensorflow/tfjs-converter

# Copy web model files to assets folder
# Move the labels JSON file into the src folder
cp -R /path/to/tf_files/web/* public/assets/model
mv public/assets/model/labels.json src/labels.json

# Add a few test images to public/assets/images manually

I'm using the react-scripts package (used by create-react-app) as the development server and build tool. This saves some time writing webpack configs, etc.

App folder structure after setting up.

/path/to/myproject-frontend
├── node_modules
│   └── ...
├── package-lock.json
├── package.json
├── public
│   ├── assets
│   │   ├── images
│   │   │   └── some-flower.jpg
│   │   └── model
│   │       ├── group1-shard1of1
│   │       ├── tensorflowjs_model.pb
│   │       └── weights_manifest.json
│   └── index.html
└── src
    ├── index.js
    └── labels.json

Add HTML / JS

Edit public/index.html to:

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Image classifier</title>
  </head>
  <body>
    <img id="input" src="assets/images/some-flower.jpg" />
    <pre id="output"></pre>
  </body>
</html>

Edit src/index.js to:

import * as tf from '@tensorflow/tfjs'
import { loadFrozenModel } from '@tensorflow/tfjs-converter'
import labels from './labels.json'

const ASSETS_URL = `${window.location.origin}/assets`
const MODEL_URL = `${ASSETS_URL}/model/tensorflowjs_model.pb`
const WEIGHTS_URL = `${ASSETS_URL}/model/weights_manifest.json`
const IMAGE_SIZE = 128 // Model input size

const loadModel = async () => {
  const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL)
  // Warm up the GPU: the first predict() call compiles the WebGL shaders, so later calls run faster
  const input = tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3])
  model.predict({ input }) // MobileNet V1
  // model.predict({ Placeholder: input }) // MobileNet V2
  return model
}

const predict = async (img, model) => {
  const t0 = performance.now()
  const image = tf.fromPixels(img).toFloat()
  const resized = tf.image.resizeBilinear(image, [IMAGE_SIZE, IMAGE_SIZE])
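  // Normalize pixel values from [0, 255] to [-1, 1], the input range the MobileNet expects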
  const offset = tf.scalar(255 / 2)
  const normalized = resized.sub(offset).div(offset)
  const input = normalized.expandDims(0)
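  // tf.tidy disposes any intermediate tensors created inside the callback, keeping only the returned prediction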
  const output = await tf.tidy(() => model.predict({ input })).data() // MobileNet V1
  // const output = await tf.tidy(() => model.predict({ Placeholder: input })).data() // MobileNet V2
  const predictions = labels
    .map((label, index) => ({ label, accuracy: output[index] }))
    .sort((a, b) => b.accuracy - a.accuracy)
  const time = `${(performance.now() - t0).toFixed(1)} ms`
  return { predictions, time }
}

const start = async () => {
  const input = document.getElementById('input')
  const output = document.getElementById('output')
  const model = await loadModel()
  const predictions = await predict(input, model)
  output.append(JSON.stringify(predictions, null, 2))
}

start()
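
As noted earlier, tf.fromPixels also accepts ImageData, canvas and video elements, so the same predict function can classify webcam frames. A rough sketch, assuming a hypothetical <video id="webcam"> element has been added to index.html (the classifyWebcam helper is not part of the example app):

const classifyWebcam = async () => {
  const video = document.getElementById('webcam')
  // Request camera access and start playback before grabbing frames
  const stream = await navigator.mediaDevices.getUserMedia({ video: true })
  video.srcObject = stream
  await video.play()
  const model = await loadModel()
  // tf.fromPixels (called inside predict) accepts an HTMLVideoElement directly
  const { predictions, time } = await predict(video, model)
  console.log(predictions[0], time)
}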

Add to package.json scripts:

"scripts": {
  "start": "react-scripts start",
  "build": "react-scripts build"
}

Run the app

To start the development server, run npm start (or npx react-scripts start).
This opens a browser window at http://localhost:3000, watches the project files and auto-reloads the browser on change.

To create a production build, run npm run build (or npx react-scripts build).
This outputs a build folder with static assets.


Result

MobileNet_V1_0.25_224: 2MB gzipped
MobileNet_V2_0.35_224: 1.1MB gzipped

gzip

It's possible to save bandwidth by serving static assets (including the model) with gzip compression. This can be done manually or by enabling gzip in your server config.

For example, add to /etc/nginx/conf.d/default.conf

server {
  ...

  gzip on;
  gzip_vary on;
  gzip_static on;
  gzip_types text/plain application/javascript application/octet-stream;
  gzip_min_length 256;
}
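
Alternatively, if you serve the build folder with a small Node server instead of nginx, gzip can be enabled with the compression middleware. A minimal sketch (this server is an assumption, not part of the original setup; it expects express and compression to be installed):

// server.js
const express = require('express')
const compression = require('compression')

const app = express()
// Compress every response, including the binary model weights (application/octet-stream)
app.use(compression({ filter: () => true }))
app.use(express.static('build'))
app.listen(8080, () => console.log('Serving build/ at http://localhost:8080'))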

If you are using the Cloudflare CDN, make sure to disable Brotli compression (for some reason it does not serve application/octet-stream files with their original gzip compression).


Please let me know if you notice any mistakes or things to improve. I'm always open to suggestions and feedback!

Credits to the creators of TensorFlow for Poets and TensorFlow.js. This guide is basically a combination of the original TensorFlow for Poets guide and the TensorFlow.js documentation. Thanks to Mateusz Budzar for a guide on how to retrain a MobileNet V2 model. Also, check out ml5js!

woudsma commented Sep 12, 2019

Hi!

Great tutorial! I've followed it all the way through, I've converted the model and gotten the site to run, but I keep getting this error when I go to localhost:
Attempted import error: 'loadFrozenModel' is not exported from '@tensorflow/tfjs-converter'.

I was wondering if you have encountered this issue while you made this or know what to do to fix it. Or maybe someone else has any ideas?
Let me know!

I've not encountered that error! It could be that tfjs doesn't export loadFrozenModel anymore since I wrote the guide. It's probably a version issue. Do you have the same problem if you clone and run the example app?

The latest version of tfjs has a loadGraphModel function; can you try whether that works instead of loadFrozenModel?
https://js.tensorflow.org/api/1.2.6/#loadGraphModel

// so instead of
const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL)
// try
const model = await tf.loadGraphModel(MODEL_URL) 

@dvbeelen


Thank you! This was indeed the problem. Changing the function to loadGraphModel also required me to reconvert the model to put out a .json file. So like this:

tensorflowjs_converter \
  --input_format=tf_frozen_model \
  --output_format=OUTPUT_JSON \
  --output_node_names=final_result \
  tf_files/quantized_graph.pb \
  tf_files/web

Right now, I have only one issue left that I can't seem to fix. I changed the img-path to my own images-files, but when I run the code, I keep receiving this error:

Unhandled Rejection (InvalidStateError): Failed to execute 'drawImage' on 'OffscreenCanvasRenderingContext2D': The 
HTMLImageElement provided is in the 'broken' state.

Perhaps you know what is causing this? I can't seem to figure it out.

woudsma commented Sep 14, 2019

Nice!
And for the drawImage issue, it sounds like the image file has not finished loading when predict runs. It could be related to:
nwjs/nw.js#2470 or
https://stackoverflow.com/questions/22430671/javascript-failed-to-execute-drawimage
Make sure to run predict only after the image has loaded, for example:

const img = new Image()
img.onload = async () => {
  const predictions = await predict(img, model)
}
img.src = '...'

I think that could fix your issue.

In case you're trying to do something with the OffscreenCanvas API, that is not supported yet in most browsers unfortunately:
https://caniuse.com/#search=offscreencanvas
