Skip to content

Instantly share code, notes, and snippets.

@rezonn
Last active May 12, 2020 15:19
Show Gist options
  • Select an option

  • Save rezonn/2643d71e03960dd78483a104516db8e7 to your computer and use it in GitHub Desktop.

Select an option

Save rezonn/2643d71e03960dd78483a104516db8e7 to your computer and use it in GitHub Desktop.
TensorFlow.js: fitting a trainable function in JavaScript
<html>
<head>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script>
<script>
/**
 * Render the points (x[i], y[i]) and (x[i], z[i]) as two series of a single
 * tfjs-vis scatter plot in the 'Charts' tab.
 *
 * Tensor values are read with dataSync(), which returns the data flattened,
 * so any shape is accepted. For example, an [n, 1] column tensor plots the
 * same as an [n] vector, and a scalar plots as a single point.
 *
 * @param {tf.Tensor} x - x coordinates shared by both series.
 * @param {tf.Tensor} y - y coordinates of the first series.
 * @param {tf.Tensor} [z] - y coordinates of the second series. When omitted
 *   (falsy), `y` is reused, so the two series coincide.
 */
function plot(x, y, z) {
  // Flatten a tensor of any rank into a plain JS array of numbers.
  const toList = (t) => Array.from(t.dataSync());
  const [xs, ys, zs] = [x, y, z || y].map(toList);
  // Pair every x with the value at the same index of `vals`.
  const toPoints = (vals) => xs.map((xv, i) => ({ x: xv, y: vals[i] }));
  tfvis.render.scatterplot(
    { name: "", tab: 'Charts' },
    { values: [toPoints(ys), toPoints(zs)] }
  );
}
window.onload = function() {
  // Trainable function: f(x; k) = k[0]*x^2 + k[1]*x + k[2], where k holds three scalar tensors.
  const f = (x, k) => k[0].mul(x.square()).add(k[1].mul(x)).add(k[2]);

  // Hidden (ground-truth) parameters, used only to generate the targets.
  // They are plain constant tensors rather than Variables so they are never
  // registered as trainable and cannot be picked up by the optimizer.
  const k0 = [3.2, 1.7, 0.7].map((m) => tf.scalar(m));

  // Train data: inputs and the targets produced by the hidden parameters.
  // tf.tidy frees the intermediate tensors and keeps only the returned result.
  const xs = tf.tensor1d([0, 1, 2, 3]);
  const ys = tf.tidy(() => f(xs, k0));
  tf.dispose(k0);

  // Random parameters to be fitted, each initialised from Math.random() in [0, 1).
  const k = Array.from({ length: 3 }, () => tf.scalar(Math.random()).variable());

  // Fit the random parameters by minimising the mean squared error with plain SGD.
  const loss = (pred, label) => pred.sub(label).square().mean();
  const learningRate = 0.01;
  const iterations = 10;
  const optimizer = tf.train.sgd(learningRate);
  for (let i = 0; i < iterations; i++) {
    // Pass `k` as varList so only these parameters are updated, not every
    // trainable variable registered with the engine.
    optimizer.minimize(() => loss(f(xs, k), ys), false, k);
  }

  // Plot the targets against the fitted function's output. plot() reads the
  // tensor synchronously, so the prediction can be released right after.
  const pred = tf.tidy(() => f(xs, k));
  plot(xs, ys, pred);
  pred.dispose();
};
</script>
</head>
<body></body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment