-
-
Save kabeer11000/f2759c8c3de1413891c4f3ab10f82b65 to your computer and use it in GitHub Desktop.
Using TensorFlow.js with MobileNet models for image classification on Node.js
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
  "name": "tf-js",
  "version": "1.0.0",
  "main": "script.js",
  "license": "MIT",
  "dependencies": {
    "@tensorflow-models/mobilenet": "^0.2.2",
    "@tensorflow/tfjs": "^0.12.3",
    "@tensorflow/tfjs-node": "^0.1.9",
    "jpeg-js": "^0.3.4"
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const tf = require('@tensorflow/tfjs') | |
const mobilenet = require('@tensorflow-models/mobilenet'); | |
require('@tensorflow/tfjs-node') | |
const fs = require('fs'); | |
const jpeg = require('jpeg-js'); | |
/**
 * Number of leading colour channels per pixel that are fed to the model
 * (the decoded RGBA buffer's fourth channel is not copied).
 */
const NUMBER_OF_CHANNELS = 3;
/**
 * Read a JPEG file from disk and decode it into raw pixel data.
 *
 * @param {string} path - Filesystem path of the JPEG image.
 * @returns {{width: number, height: number, data: Uint8Array}} The object
 *   returned by `jpeg.decode`; its `data` is consumed elsewhere with a stride
 *   of 4 values per pixel (RGBA).
 * @throws {Error} If the file cannot be read or is not a decodable JPEG.
 */
const readImage = (path) => {
  const buf = fs.readFileSync(path);
  // Use the options-object form. Current jpeg-js releases merge the second
  // argument as an options object, so a bare `true` is silently ignored there.
  // Legacy releases that took a boolean treat any truthy value as enabling
  // typed-array output, so this form works on both.
  return jpeg.decode(buf, { useTArray: true });
};
/**
 * Copy the first `numChannels` channels of every pixel out of an interleaved
 * pixel buffer into a flat, tightly packed Int32Array (pixel-major, channels
 * interleaved). The result is suitable for `tf.tensor3d`.
 *
 * @param {{width: number, height: number, data: ArrayLike<number>}} image -
 *   Decoded image whose `data` holds `srcChannels` values per pixel.
 * @param {number} numChannels - Number of leading channels to keep per pixel.
 * @param {number} [srcChannels=4] - Values per pixel in `image.data`. The
 *   default of 4 matches the RGBA layout this script always assumed.
 * @returns {Int32Array} `width * height * numChannels` values.
 * @throws {RangeError} If `numChannels` exceeds `srcChannels`. Without this
 *   check the copy would silently pull bytes from the neighbouring pixel.
 */
const imageByteArray = (image, numChannels, srcChannels = 4) => {
  if (numChannels > srcChannels) {
    throw new RangeError(
      `numChannels (${numChannels}) cannot exceed srcChannels (${srcChannels})`,
    );
  }
  const pixels = image.data;
  const numPixels = image.width * image.height;
  const values = new Int32Array(numPixels * numChannels);
  for (let i = 0; i < numPixels; i++) {
    for (let channel = 0; channel < numChannels; channel++) {
      values[i * numChannels + channel] = pixels[i * srcChannels + channel];
    }
  }
  return values;
};
/**
 * Build an int32 tensor of shape [height, width, numChannels] from a decoded
 * image, keeping the first `numChannels` channels of each pixel.
 *
 * @param {{width: number, height: number, data: ArrayLike<number>}} image -
 *   Decoded image as produced by `readImage`.
 * @param {number} numChannels - Channels to keep per pixel.
 * @returns {tf.Tensor3D} The image tensor (dtype 'int32').
 */
const imageToInput = (image, numChannels) =>
  tf.tensor3d(
    imageByteArray(image, numChannels),
    [image.height, image.width, numChannels],
    'int32',
  );
/**
 * Load a MobileNet classifier from a local model location.
 *
 * Constructs the wrapper with arguments (1, 1), presumably MobileNet version 1
 * with alpha 1.0 (confirm against @tensorflow-models/mobilenet 0.2.x). It then
 * overwrites the instance's `path` with a `file://` URL before calling
 * `load()`, so the weights are read from disk rather than from whatever
 * location the library would otherwise use.
 *
 * @param {string} path - Filesystem path of the model.
 * @returns {Promise<mobilenet.MobileNet>} The loaded classifier.
 */
const loadModel = async (path) => {
  const classifier = new mobilenet.MobileNet(1, 1);
  classifier.path = `file://${path}`;
  await classifier.load();
  return classifier;
};
/**
 * Classify a JPEG image with a MobileNet model loaded from disk and print the
 * predictions.
 *
 * @param {string} model - Model path, forwarded to `loadModel`.
 * @param {string} path - Path of the JPEG image to classify.
 * @returns {Promise<Array>} The predictions, which are also logged to the
 *   console.
 */
const classify = async (model, path) => {
  const image = readImage(path);
  const input = imageToInput(image, NUMBER_OF_CHANNELS);
  try {
    const mnModel = await loadModel(model);
    const predictions = await mnModel.classify(input);
    console.log('classification results:', predictions);
    return predictions;
  } finally {
    // TF.js tensor memory is released explicitly rather than by the JS GC.
    // Free the input even if loading or classification throws.
    input.dispose();
  }
};
// Entry point: expects exactly two CLI arguments, the model path and the image path.
if (process.argv.length !== 4) throw new Error('incorrect arguments: node script.js <MODEL> <IMAGE_FILE>');
// Handle rejections explicitly. A failed load or classification then prints the
// error and exits non-zero instead of surfacing as an unhandled rejection.
classify(process.argv[2], process.argv[3]).catch((err) => {
  console.error(err);
  process.exitCode = 1;
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment