Skip to content

Instantly share code, notes, and snippets.

@areichman
Last active August 5, 2024 22:49
Show Gist options
  • Save areichman/3fc08eb44fdaf076819316d6346fa75d to your computer and use it in GitHub Desktop.
Save areichman/3fc08eb44fdaf076819316d6346fa75d to your computer and use it in GitHub Desktop.
Perform image classification with TensorFlow.js on AWS Lambda
const tf = require('@tensorflow/tfjs');
const jpeg = require('jpeg-js');
const axios = require('axios');
/** Side length, in pixels, of the square image fed to the model. */
const MODEL_INPUT_SIZE = 224;

/**
 * Copy the R, G and B channels out of an interleaved 4-bytes-per-pixel
 * (RGBA) buffer, dropping the fourth channel.
 * @param {ArrayLike<number>} rgba - interleaved pixel data, 4 values per pixel
 * @param {number} numPixels - number of pixels to copy
 * @returns {Int32Array} interleaved RGB values, 3 per pixel
 */
function rgbaToRgb(rgba, numPixels) {
  const channels = 3;
  const rgb = new Int32Array(numPixels * channels);
  for (let i = 0; i < numPixels; i++) {
    for (let channel = 0; channel < channels; channel++) {
      rgb[i * channels + channel] = rgba[i * 4 + channel];
    }
  }
  return rgb;
}

/**
 * Compute the slice box of the largest centered square in a height x width image.
 * Offsets are floored because tf.slice needs integer coordinates; an odd
 * difference between height and width would otherwise yield e.g. 0.5.
 * @param {number} height
 * @param {number} width
 * @returns {{begin: number[], size: number[]}} `[row, col]` start and `[size, size]` extent
 */
function centerCropBox(height, width) {
  const size = Math.min(width, height);
  return {
    begin: [Math.floor((height - size) / 2), Math.floor((width - size) / 2)],
    size: [size, size],
  };
}

/**
 * Decode a JPEG into a tensor of shape [height, width, 3].
 * Based on https://dev.to/ibmdeveloper/machine-learning-in-nodejs-with-tensorflowjs-1g1p
 * The caller owns the returned tensor and must dispose it.
 * @param {ArrayBuffer|Uint8Array} imageBuffer - raw JPEG bytes
 * @returns {tf.Tensor}
 */
function decodeJpegToTensor(imageBuffer) {
  const { width, height, data } = jpeg.decode(imageBuffer, { useTArray: true });
  return tf.tensor(rgbaToRgb(data, width * height), [height, width, 3]);
}

/**
 * Center crop an image tensor to a square, resize it to targetSize x targetSize,
 * and add a leading batch dimension. Intermediate tensors are freed by tf.tidy;
 * only the returned batch tensor survives (the input tensor is not disposed).
 * @param {tf.Tensor} imageTensor - tensor of shape [height, width, channels]
 * @param {number} [targetSize=MODEL_INPUT_SIZE]
 * @returns {tf.Tensor} tensor of shape [1, targetSize, targetSize, channels]
 */
function preprocess(imageTensor, targetSize = MODEL_INPUT_SIZE) {
  return tf.tidy(() => {
    const [height, width] = imageTensor.shape;
    const { begin, size } = centerCropBox(height, width);
    const cropped = imageTensor.slice(begin, size);
    return tf.image.resizeBilinear(cropped, [targetSize, targetSize]).expandDims(0);
  });
}

/**
 * Lambda entry point: load a TF.js layers model, fetch a JPEG, and classify it.
 * @param {{modelUrl: string, imageUrl: string}} event
 *   modelUrl is passed to tf.loadLayersModel; imageUrl is fetched with axios
 *   and decoded as a JPEG.
 * @returns {Promise<{result: *, modelUrl: string, imageUrl: string}>}
 *   `result` is the model output for the single image in the batch.
 * @throws rethrows any load/fetch/decode/predict error after logging it as JSON.
 */
async function handler(event) {
  // Read the inputs outside the try so the catch block can include them in the log.
  const { modelUrl, imageUrl } = event || {};
  let model;
  let imageTensor;
  let batched;
  let result;
  try {
    // fetch the model and the raw image bytes
    model = await tf.loadLayersModel(modelUrl);
    const { data: imageBuffer } = await axios({
      method: 'get',
      url: imageUrl,
      responseType: 'arraybuffer',
    });

    // decode, then center crop and resize to the model's input resolution
    imageTensor = decodeJpegToTensor(imageBuffer);
    batched = preprocess(imageTensor);

    // run the prediction and return the results
    result = model.predict(batched);
    const body = await result.array();
    const response = {
      result: body[0],
      modelUrl,
      imageUrl,
    };
    console.log(JSON.stringify(response));
    return response;
  } catch (err) {
    console.log(JSON.stringify({
      error: err.message,
      modelUrl,
      imageUrl,
    }));
    // Propagate so the invocation is reported as failed instead of returning null.
    throw err;
  } finally {
    // Release tensor memory; Lambda may reuse this process for later invocations,
    // so anything not disposed here would accumulate.
    tf.dispose([imageTensor, batched, result]);
    if (model) {
      model.dispose();
    }
  }
}
// Expose the Lambda entry point under the name `handler` on this CommonJS module.
module.exports.handler = handler;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment