$ node mobilenet-node.js ./model image.jpg
-
-
Save kostasx/1562671045aee2c0eb98363c69aecae9 to your computer and use it in GitHub Desktop.
Using TensorFlow.js with MobileNet models for image classification on Node.js
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// TensorFlow.js core API; used below to build the input tensor.
const tf = require('@tensorflow/tfjs');
// MobileNet model wrapper used for classification.
const mobilenet = require('@tensorflow-models/mobilenet');
// Loaded for its side effects only: the return value is intentionally unused.
require('@tensorflow/tfjs-node');
const fs = require('fs');
const jpeg = require('jpeg-js');

// How many channels per pixel are copied into the model input.
const NUMBER_OF_CHANNELS = 3;
/**
 * Read a JPEG file synchronously and decode it with jpeg-js.
 *
 * @param {string} path - Location of the JPEG file on disk.
 * @returns {object} The value produced by `jpeg.decode`; the rest of this
 *   script reads its `data`, `width` and `height` properties.
 */
const readImage = path => {
  const fileContents = fs.readFileSync(path);
  // The second argument (`true`) is passed straight through to jpeg-js.
  return jpeg.decode(fileContents, true);
};
/**
 * Copy the first `numChannels` values of every pixel into a flat Int32Array.
 *
 * The source buffer is read with a fixed stride of four values per pixel
 * (as in an RGBA layout), so at most four channels can be extracted.
 *
 * @param {{data: ArrayLike<number>, width: number, height: number}} image
 *   Decoded image whose `data` holds four values per pixel.
 * @param {number} [numChannels=3] - Leading channels to keep per pixel (1-4).
 * @returns {Int32Array} `width * height * numChannels` values, pixel by pixel.
 * @throws {RangeError} If `numChannels` is not an integer in 1..4, if the
 *   dimensions are not non-negative integers, or if `data` is too short for
 *   the stated dimensions.
 */
const imageByteArray = (image, numChannels = 3) => {
  const SOURCE_STRIDE = 4;

  // A larger channel count would silently read into the following pixel.
  if (!Number.isInteger(numChannels) || numChannels < 1 || numChannels > SOURCE_STRIDE) {
    throw new RangeError(
      `numChannels must be an integer between 1 and ${SOURCE_STRIDE}, got ${numChannels}`
    );
  }

  const { data, width, height } = image;
  const numPixels = width * height;
  if (!Number.isInteger(numPixels) || numPixels < 0) {
    throw new RangeError(`invalid image dimensions: ${width}x${height}`);
  }

  // A short buffer would otherwise be padded with zeros without any warning.
  if (data.length < numPixels * SOURCE_STRIDE) {
    throw new RangeError(
      `image data has ${data.length} values, expected at least ${numPixels * SOURCE_STRIDE}`
    );
  }

  const values = new Int32Array(numPixels * numChannels);
  for (let pixel = 0; pixel < numPixels; pixel++) {
    const src = pixel * SOURCE_STRIDE;
    const dst = pixel * numChannels;
    for (let channel = 0; channel < numChannels; channel++) {
      values[dst + channel] = data[src + channel];
    }
  }
  return values;
};
/**
 * Turn a decoded image into an int32 rank-3 tensor shaped
 * `[height, width, numChannels]`.
 *
 * @param {{data: ArrayLike<number>, width: number, height: number}} image
 *   Decoded image to convert.
 * @param {number} numChannels - Channels kept per pixel.
 * @returns {object} The tensor created by `tf.tensor3d`.
 */
const imageToInput = (image, numChannels) => {
  return tf.tensor3d(
    imageByteArray(image, numChannels),
    [image.height, image.width, numChannels],
    'int32'
  );
};
/**
 * Create a MobileNet instance, point it at a local model location and load it.
 *
 * @param {string} path - Local model location; it is prefixed with `file://`.
 * @returns {Promise<object>} The MobileNet instance once `load()` has resolved.
 */
const loadModel = async path => {
  // NOTE(review): the two "1.00" arguments are presumably version and alpha;
  // confirm against the installed @tensorflow-models/mobilenet release.
  const net = new mobilenet.MobileNet("1.00", "1.00");
  // Override the instance's `path` so loading reads from the local filesystem.
  net.path = `file://${path}`;
  await net.load();
  return net;
};
/**
 * Classify a JPEG image with a locally stored MobileNet model and log the
 * predictions.
 *
 * @param {string} model - Local model location handed to `loadModel`.
 * @param {string} path - JPEG file to classify.
 * @returns {Promise<*>} The predictions returned by the model's `classify`.
 */
const classify = async (model, path) => {
  const image = readImage(path);
  const input = imageToInput(image, NUMBER_OF_CHANNELS);
  try {
    const mn_model = await loadModel(model);
    const predictions = await mn_model.classify(input);
    console.log('classification results:', predictions);
    return predictions;
  } finally {
    // Tensors are not garbage-collected automatically; release the input
    // even when loading or classification fails.
    input.dispose();
  }
};
// Entry point: expects exactly two CLI arguments, <MODEL> and <IMAGE_FILE>.
if (process.argv.length !== 4) {
  throw new Error('incorrect arguments: node script.js <MODEL> <IMAGE_FILE>');
}

// Handle rejections explicitly so a failed run is reported and exits
// with a non-zero status instead of leaving a floating promise.
classify(process.argv[2], process.argv[3]).catch(err => {
  console.error(err);
  process.exitCode = 1;
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
  "name": "tf-js",
  "version": "1.0.0",
  "main": "script.js",
  "license": "MIT",
  "dependencies": {
    "@tensorflow-models/mobilenet": "^1.0.1",
    "@tensorflow/tfjs": "^1.2.2",
    "@tensorflow/tfjs-node": "^1.2.3",
    "jpeg-js": "^0.3.5"
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment