Skip to content

Instantly share code, notes, and snippets.

@radi-cho
Created September 23, 2018 19:07
Show Gist options
  • Save radi-cho/c315caca4abf418d4437da1e0e9ecd4d to your computer and use it in GitHub Desktop.
Check whether the model must be trained from scratch or an existing model can be retrained.
// TFjs and gcs are using a lot of memory, so increasing it will speed up the functions
exports.train = functions
.runWith({ memory: "2GB" })
.https.onRequest(async (request, response) => {
const existJSON = await bucket.file("model.json").exists().then(ex => ex[0]);
const existBIN = await bucket.file("weights.bin").exists().then(ex => ex[0]);
if (!existJSON || !existBIN) {
const model = tf.sequential();
model.add(tf.layers.dense({ units: 2, inputShape: [vocabulary.length] }));
model.add(tf.layers.dense({ units: 1, inputShape: [2] }));
model.compile({ loss: "meanSquaredError", optimizer: "sgd" });
await db.collection("comments").get()
.then(async querySnapshot => {
// Get tensor-like arrays from Firestore
return await trainSave(model, querySnapshot);
});
} else {
// Download the files if they exist
await bucket.file("model.json").download({ destination: join(tmpdir, "model.json") });
await bucket.file("weights.bin").download({ destination: join(tmpdir, "weights.bin") });
const model = await tf.loadModel("file://" + join(tmpdir, "model.json"));
model.compile({ loss: "meanSquaredError", optimizer: "sgd" });
const lastUpdated = await bucket.file("weights.bin").getMetadata()
.then(metadata => new Date(metadata[0].updated));
await db.collection("comments").where("publishedAt", ">", lastUpdated).get()
.then(async querySnapshot => {
return await trainSave(model, querySnapshot);
});
}
response.send("Success");
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment