Last active
August 5, 2024 22:49
-
-
Save areichman/3fc08eb44fdaf076819316d6346fa75d to your computer and use it in GitHub Desktop.
Perform image classification using TensorFlow.js and AWS Lambda
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const tf = require('@tensorflow/tfjs'); | |
const jpeg = require('jpeg-js'); | |
const axios = require('axios'); | |
/** Side length, in pixels, of the square input image fed to the model. */
const MODEL_INPUT_SIZE = 224;

/**
 * Decode JPEG bytes into an int32 tensor of shape [height, width, 3].
 *
 * The decoded pixel buffer is read with a stride of 4 bytes per pixel and
 * only the first three bytes of each pixel are kept (the indexing assumes
 * RGBA order, dropping alpha).
 * Based on https://dev.to/ibmdeveloper/machine-learning-in-nodejs-with-tensorflowjs-1g1p
 *
 * @param {*} imageBuffer - Raw JPEG bytes, as returned by axios with
 *   `responseType: 'arraybuffer'`.
 * @returns {tf.Tensor} Pixel tensor of shape [height, width, 3].
 */
function decodeJpegToTensor(imageBuffer) {
  const numChannels = 3;
  const { width, height, data: pixelBytes } = jpeg.decode(imageBuffer, { useTArray: true });
  const numPixels = width * height;
  const values = new Int32Array(numPixels * numChannels);
  for (let pixel = 0; pixel < numPixels; pixel++) {
    for (let channel = 0; channel < numChannels; channel++) {
      values[pixel * numChannels + channel] = pixelBytes[pixel * 4 + channel];
    }
  }
  return tf.tensor(values, [height, width, numChannels]);
}

/**
 * Center-crop an image tensor to a square, then bilinearly resize it.
 *
 * @param {tf.Tensor} imageTensor - Image tensor of shape [height, width, channels].
 * @param {number} targetSize - Side length of the resized output.
 * @returns {tf.Tensor} Tensor of shape [targetSize, targetSize, channels].
 */
function centerCropAndResize(imageTensor, targetSize) {
  const [height, width] = imageTensor.shape;
  const side = Math.min(width, height);
  // Floor the offsets: an odd difference between the dimensions would
  // otherwise produce a fractional (x.5) slice coordinate.
  const top = Math.floor((height - side) / 2);
  const left = Math.floor((width - side) / 2);
  // Only the first two axes are given; the channel axis is taken whole.
  const cropped = imageTensor.slice([top, left], [side, side]);
  return tf.image.resizeBilinear(cropped, [targetSize, targetSize]);
}

/**
 * AWS Lambda entry point: classify the image at `event.imageUrl` using the
 * TensorFlow.js layers model at `event.modelUrl`.
 *
 * @param {{ modelUrl: string, imageUrl: string }} event - Invocation payload.
 * @returns {Promise<{ result: *, modelUrl: string, imageUrl: string } | undefined>}
 *   The model output for the single input image together with the input
 *   URLs, or `undefined` if any step failed (the error is logged, not thrown).
 */
async function handler(event) {
  // Destructure outside the try block so the URLs are still in scope for the
  // error log in the catch clause; `|| {}` lets a missing payload reach that
  // log instead of throwing here.
  const { modelUrl, imageUrl } = event || {};
  let model;
  try {
    // fetch the model and the image data
    model = await tf.loadLayersModel(modelUrl);
    const { data: imageBuffer } = await axios({
      method: 'get',
      url: imageUrl,
      responseType: 'arraybuffer'
    });

    // tidy() disposes every intermediate tensor created inside the callback
    // and keeps only the returned prediction, so tensors do not pile up if
    // the process is reused across invocations.
    const prediction = tf.tidy(() => {
      const imageTensor = decodeJpegToTensor(imageBuffer);
      const input = centerCropAndResize(imageTensor, MODEL_INPUT_SIZE);
      return model.predict(input.expandDims(0));
    });

    let body;
    try {
      body = await prediction.array();
    } finally {
      prediction.dispose();
    }

    const response = {
      // strip the batch dimension added by expandDims(0)
      result: body[0],
      modelUrl,
      imageUrl
    };
    console.log(JSON.stringify(response));
    return response;
  } catch (err) {
    // NOTE(review): the error is swallowed, so the handler resolves to
    // undefined instead of rejecting; consider rethrowing if callers need
    // to distinguish failures.
    console.log(JSON.stringify({
      error: err.message,
      modelUrl,
      imageUrl
    }));
  } finally {
    // The model is loaded fresh on every call; release its weights.
    if (model) {
      model.dispose();
    }
  }
}

exports.handler = handler;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment