Skip to content

Instantly share code, notes, and snippets.

@bellbind
Last active April 10, 2018 08:21
Show Gist options
  • Save bellbind/52c049c07dbd0ddc615f4e32109421bf to your computer and use it in GitHub Desktop.
Save bellbind/52c049c07dbd0ddc615f4e32109421bf to your computer and use it in GitHub Desktop.
[javascript][tensorflowjs] train XOR function with tensorflow.js
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<!-- tensorflow.js from the CDN; provides the global `tf` object that script.js uses. -->
<!-- NOTE(review): the URL is unversioned, so it always loads the latest tf.js release; consider pinning a version. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- defer: script.js runs after the document is parsed, so document.body is available to it. -->
<script src="script.js" defer="defer"></script>
</head>
<!-- script.js inserts its <pre> result blocks at the top of this body. -->
<body></body>
</html>
// `tf`: root object of tensorflow.js
// build keras model: 2 inputs -> 4 relu units -> 2-way softmax (one score per XOR result)
const model = tf.sequential(); // factory function for keras Sequential obj
model.add(tf.layers.dense({units: 4, inputDim: 2, activation: "relu"}));
model.add(tf.layers.dense({units: 2, activation: "softmax"}));
// names for keras are turned as camelCase string
model.compile({loss: "categoricalCrossentropy", optimizer: "sgd"});
// tensor data: similar as numpy array
const inputs = [[0, 0], [0, 1], [1, 0], [1, 1]];
// XOR truth table: class label 0 or 1 for each input row
const outputs = inputs.map(([a, b]) => a ^ b);
const xs = tf.tensor(inputs);
// tf.oneHot requires int32 indices; tf.tensor(outputs) would default to float32,
// which is rejected (or, on some backends, only warned about and training then stalls).
// tf.tidy releases the intermediate index tensor and keeps only the returned one-hot tensor.
const ys = tf.tidy(() => tf.oneHot(tf.tensor1d(outputs, "int32"), 2));
// train and check: 100 rounds of a 100-epoch fit(), rendering predictions after each round
(async () => {
  // fixed probe inputs whose predictions are displayed after every round
  const tx = tf.tensor([[1, 1], [0, 1], [0, 0], [1, 0]]);
  performance.mark(`start`);
  for (let i = 0; i < 100; i++) {
    performance.mark(`start-fit-${i}`);
    await model.fit(xs, ys, {epochs: 100});
    performance.mark(`end-fit-${i}`);
    performance.measure(`fit-${i}`, `start-fit-${i}`, `end-fit-${i}`);
    const ty = model.predict(tx);
    try {
      await display(i, tx, ty);
    } finally {
      // tf.js tensors are not garbage-collected; release this round's prediction
      ty.dispose();
    }
    // yield to the event loop so the page can repaint between rounds
    await new Promise((resolve) => setTimeout(resolve, 10));
  }
  tx.dispose();
  performance.mark(`end`);
  performance.measure(`total`, `start`, `end`);
  await displayModel(model);
  await displayWeights(model);
  await displayDuration();
})().catch(console.error);
// outputs
/**
 * Render one training round's predictions as a <pre> at the top of the page.
 * @param {number} i - zero-based round index; selects the `fit-${i}` performance measure
 * @param tx - probe inputs; its flat data is read as consecutive (a, b) pairs
 * @param ty - predicted scores; its flat data is read as consecutive (p0, p1) pairs
 */
async function display(i, tx, ty) {
  // argMax allocates a new tensor; dispose it as soon as its values are read
  const labels = ty.argMax(1);
  let data;
  try {
    data = await Promise.all([tx.data(), ty.data(), labels.data()]);
  } finally {
    labels.dispose();
  }
  const [xs, ys, vs] = data;
  const ms = performance.getEntriesByName(`fit-${i}`)[0].duration;
  const ls = Array.from(vs, (v, row) => {
    const l = 2 * row;
    const r = l + 1;
    return `${xs[l]} ^ ${xs[r]} = ${v} (0=${ys[l]}, 1=${ys[r]})\n`;
  });
  const pre = document.createElement("pre");
  pre.textContent = `[Train XOR: ${i + 1} (${ms}ms)]\n` + ls.join("");
  document.body.insertBefore(pre, document.body.firstChild);
}
/**
 * Prepend a <pre> reporting the `total` performance measure, in seconds.
 * Expects performance.measure(`total`, ...) to have been recorded beforehand.
 */
async function displayDuration() {
  const [totalEntry] = performance.getEntriesByName(`total`);
  const seconds = totalEntry.duration / 1000;
  const block = document.createElement("pre");
  block.textContent = `(total: ${seconds}sec)`;
  const { body } = document;
  body.insertBefore(block, body.firstChild);
}
/**
 * Prepend a <pre> containing the model's JSON description, pretty-printed
 * with 2-space indentation.
 * @param model - object whose toJSON() returns a JSON string
 */
async function displayModel(model) {
  const block = document.createElement("pre");
  const parsed = JSON.parse(model.toJSON());
  const indent = 2;
  block.textContent = JSON.stringify(parsed, null, indent);
  const { body } = document;
  body.insertBefore(block, body.firstChild);
}
/**
 * Prepend a <pre> listing every model weight as one JSON line
 * of the form {"name", "shape", "values"} (values flattened).
 * @param model - tf.js model whose `weights` variables are listed
 */
async function displayWeights(model) {
  // read() is the public accessor for a layer variable's current value (the
  // original used the internal `val` field); async data() avoids blocking the
  // page the way dataSync() does.
  const ls = await Promise.all(model.weights.map(async (weight) => {
    const values = await weight.read().data();
    return JSON.stringify({name: weight.name, shape: weight.shape, values: [...values]});
  }));
  const pre = document.createElement("pre");
  pre.textContent = ls.join("\n");
  document.body.insertBefore(pre, document.body.firstChild);
}
@bellbind
Copy link
Author

bellbind commented Apr 5, 2018

@bellbind
Copy link
Author

bellbind commented Apr 6, 2018

NOTE:

  • Slow (and fit() may produce poor results) while the tab is hidden
  • On Safari, tf.oneHot() emits a warning and fit() then fails to update the weights

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment