Created
September 23, 2018 19:03
-
-
Save radi-cho/f8b69d1f2bd6588d555c10ff49915937 to your computer and use it in GitHub Desktop.
Method that trains the bag-of-words (BoW) model and saves it.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Trains `model` on the labelled documents in `querySnapshot`, saves the
 * trained model to a temporary directory, uploads the saved files to
 * `bucket`, and deletes the temporary copies.
 *
 * Relies on names defined elsewhere in the file: `tf`, `fitData`, `join`,
 * `tmpdir`, `bucket`, and `fs`.
 *
 * @param {object} model - object exposing `fit(xs, ys, options)` and
 *   `save(url)` (presumably a TensorFlow.js model — confirm against caller).
 * @param {object} querySnapshot - object with a `docs` array; each doc's
 *   `data()` returns `{ text, y }`. `text` is encoded with `fitData`, and
 *   `y === "positive"` is labelled 1 while any other value is labelled 0.
 * @returns {Promise<boolean>} `false` when there are no documents to train
 *   on; `true` after the model files have been saved and uploaded.
 * @throws Rethrows any error from training, saving, or uploading; the
 *   temporary files and tensors are cleaned up first.
 */
const trainSave = async (model, querySnapshot) => {
  const { docs } = querySnapshot;
  if (!docs.length) return false;

  // Read each document once, then derive the features and 0/1 labels.
  const records = docs.map((doc) => doc.data());
  const featureRows = records.map(({ text }) => fitData(text));
  const labelRows = records.map(({ y }) => (y === "positive" ? [1] : [0]));

  const xs = tf.tensor2d(featureRows);
  const ys = tf.tensor2d(labelRows);
  try {
    // train the model
    await model.fit(xs, ys, { epochs: 5 });
  } finally {
    // Tensor memory is not garbage-collected; release it even if fit throws.
    xs.dispose();
    ys.dispose();
  }

  // `tmpdir` may be either `os.tmpdir` itself or its precomputed result;
  // `join` needs a string, so call it when it is the function.
  const tempRoot = typeof tmpdir === "function" ? tmpdir() : tmpdir;
  const modelPath = join(tempRoot, "model");
  const tempJSONPath = join(modelPath, "model.json");
  const tempBINPath = join(modelPath, "weights.bin");

  // Best-effort deletion of a temporary file. A missing file is expected when
  // save() failed before writing it; other errors are reported but swallowed
  // so they cannot mask the error that triggered the cleanup.
  const removeTempFile = (filePath) => {
    try {
      fs.unlinkSync(filePath);
    } catch (err) {
      if (err.code !== "ENOENT") {
        console.warn(`Could not delete temporary file ${filePath}:`, err);
      }
    }
  };

  try {
    await model.save(`file://${modelPath}`);
    await bucket.upload(tempJSONPath);
    await bucket.upload(tempBINPath);
    console.log("New files, uploaded.");
  } finally {
    // Delete the temporary files whether or not the upload succeeded.
    removeTempFile(tempJSONPath);
    removeTempFile(tempBINPath);
  }
  return true;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment