@radi-cho
Created September 23, 2018 18:57
TF.js linear model inside Firebase Functions
// Assumed setup (not shown in the original snippet): Firebase SDKs, Firestore, and TF.js.
const functions = require("firebase-functions");
const admin = require("firebase-admin");
const tf = require("@tensorflow/tfjs");

admin.initializeApp();
const db = admin.firestore();

exports.runLinearModel = functions.https.onRequest((request, response) => {
  // Get the x_test value from the request body.
  const x_test = Number(request.body.x);

  // Number() always returns a number, so only NaN needs checking.
  // Return early so the function stops after reporting the error.
  if (isNaN(x_test)) {
    return response.status(400).send("Error! Please format your request body.");
  }

  // Define a model for linear regression.
  const linearModel = tf.sequential();
  linearModel.add(tf.layers.dense({ units: 1, inputShape: [1] }));

  // Prepare the model for training: specify the loss and the optimizer.
  linearModel.compile({ loss: "meanSquaredError", optimizer: "sgd" });

  // Process the Firestore data.
  return db.collection("linear-values").get()
    .then(async querySnapshot => {
      // Get tensor-like arrays from Firestore.
      const xs_data = querySnapshot.docs.map(doc => doc.data().x);
      const ys_data = querySnapshot.docs.map(doc => doc.data().y);

      // Train the model with those arrays, shaped [numSamples, 1] to match inputShape.
      const xs = tf.tensor2d(xs_data, [xs_data.length, 1]);
      const ys = tf.tensor2d(ys_data, [ys_data.length, 1]);
      await linearModel.fit(xs, ys); // fit() runs a single epoch unless `epochs` is set

      // Make a prediction; predict() is synchronous, so no await is needed.
      const result = linearModel.predict(tf.tensor2d([x_test], [1, 1]));
      const prediction = result.dataSync()[0];

      // Send the prediction back as a response.
      return response.status(200).json(prediction);
    })
    .catch(e => {
      response.status(500).send("Database error! " + e);
    });
});
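
The handler above converts request.body.x with Number(), which maps some non-numeric inputs (an empty string, null) to 0 instead of NaN, so they pass the NaN check. A stricter check could look roughly like the sketch below. parseX is a hypothetical helper name, not part of the original gist, and it assumes the body carries x as either a number or a numeric string.

```javascript
// Hypothetical helper (not in the original gist): returns a finite number,
// or null when x is missing, empty, or not numeric.
function parseX(body) {
  const raw = body == null ? undefined : body.x;

  // Accept real numbers, but reject NaN and +/-Infinity.
  if (typeof raw === "number") {
    return Number.isFinite(raw) ? raw : null;
  }

  // Reject anything that is not a non-empty string (null, undefined, objects, "").
  if (typeof raw !== "string" || raw.trim() === "") {
    return null;
  }

  // Convert numeric strings such as "3.5"; non-numeric strings become NaN.
  const value = Number(raw);
  return Number.isFinite(value) ? value : null;
}
```

Inside the function, the result could be tested with `if (x_test === null)` before building the model, so that malformed requests are rejected before any Firestore read or training happens.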