@radi-cho
Created September 23, 2018 19:03
Method used to train and save the BoW (bag-of-words) model.
const trainSave = async (model, querySnapshot) => {
  // Nothing to train on
  if (!querySnapshot.docs.length) return false;

  // Encode each document's text as a feature vector and its label as [1] or [0]
  const xs_data = querySnapshot.docs.map(doc => fitData(doc.data().text));
  const ys_data = querySnapshot.docs.map(
    doc => (doc.data().y === "positive" ? [1] : [0])
  );
  const xs = tf.tensor2d(xs_data);
  const ys = tf.tensor2d(ys_data);

  // Train the model
  await model.fit(xs, ys, { epochs: 5 });

  // Save the model to a temporary directory
  // (tmpdir is assumed to hold a directory path string, e.g. the result of os.tmpdir())
  const modelPath = join(tmpdir, "model");
  const tempJSONPath = join(modelPath, "model.json");
  const tempBINPath = join(modelPath, "weights.bin");
  await model.save("file://" + modelPath);

  // Upload both model files to the storage bucket
  await bucket.upload(tempJSONPath);
  await bucket.upload(tempBINPath);
  console.log("New files uploaded.");

  // Delete the temporary files
  fs.unlinkSync(tempJSONPath);
  fs.unlinkSync(tempBINPath);
  return true;
};
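
The snippet calls `fitData` without showing it. As a rough illustration only, here is a minimal sketch of what a bag-of-words encoder of that kind might look like, together with the same label mapping `trainSave` uses. The `vocabulary` array, the `fitDataSketch` and `encodeLabel` names, the tokenizer, and the sample documents are all hypothetical and not taken from the gist; the real `fitData` may work differently.

```javascript
// Hypothetical vocabulary; the real one is not shown in the gist.
const vocabulary = ["good", "bad", "great", "terrible", "love", "hate"];

// Hypothetical stand-in for fitData: maps a text to a fixed-length
// vector of word counts over the vocabulary (a bag-of-words encoding).
const fitDataSketch = text => {
  const tokens = text.toLowerCase().match(/[a-z']+/g) || [];
  return vocabulary.map(word => tokens.filter(t => t === word).length);
};

// Same label mapping as in trainSave: "positive" -> [1], anything else -> [0].
const encodeLabel = y => (y === "positive" ? [1] : [0]);

// Made-up sample documents shaped like doc.data() in trainSave.
const docs = [
  { text: "Great product, I love it. Really great!", y: "positive" },
  { text: "Terrible. I hate it.", y: "negative" }
];

const xsData = docs.map(d => fitDataSketch(d.text));
const ysData = docs.map(d => encodeLabel(d.y));

console.log(xsData); // [ [ 0, 0, 2, 0, 1, 0 ], [ 0, 0, 0, 1, 0, 1 ] ]
console.log(ysData); // [ [ 1 ], [ 0 ] ]
```

Every row of `xsData` has the same length as the vocabulary, which is what allows `tf.tensor2d` to build a rectangular `[numDocs, vocabularySize]` tensor from it; `ysData` becomes a `[numDocs, 1]` tensor of binary labels.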