Last active: May 12, 2020 15:19
-
-
Save rezonn/2643d71e03960dd78483a104516db8e7 to your computer and use it in GitHub Desktop.
TensorFlow.js trainable function javascript
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| <html> | |
| <head> | |
| <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script> | |
| <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script> | |
| <script> | |
/**
 * Render a scatter plot of y (and optionally z) against x in the tfjs-vis visor
 * ("Charts" tab).
 *
 * Each argument is a tensor read synchronously with arraySync(); the i-th value
 * of y/z is paired with the i-th value of x, so they are presumably 1-D tensors
 * of equal length (not checked here).
 *
 * @param {tf.Tensor} x - x coordinates shared by both series.
 * @param {tf.Tensor} y - first series of y values.
 * @param {tf.Tensor} [z] - second series of y values; when omitted, y is plotted twice.
 */
function plot(x, y, z) {
  const [xVals, yVals, zVals] = [x, y, z || y].map((t) => t.arraySync());
  // Pair every x value with the value at the same index of the given series.
  const toPoints = (vals) => xVals.map((xv, i) => ({ x: xv, y: vals[i] }));
  tfvis.render.scatterplot(
    { name: "", tab: 'Charts' },
    { values: [toPoints(yVals), toPoints(zVals)] }
  );
}
window.onload = function () {
  // Trainable function: f(x) = k[0]*x^2 + k[1]*x + k[2], where k holds three scalar tensors.
  const f = (x, k) => k[0].mul(x.square()).add(k[1].mul(x)).add(k[2]);

  // Hidden parameters used only to generate the training targets.
  // Plain constant tensors (not variables), so no optimizer can ever update them.
  const kHidden = [3.2, 1.7, 0.7].map((m) => tf.scalar(m));

  // Train data: inputs and the exact targets produced by the hidden parameters.
  const xs = tf.tensor1d([0, 1, 2, 3]);
  const ys = f(xs, kHidden);

  // Random starting parameters, each drawn uniformly from [0, 1).
  const k = Array.from({ length: 3 }, () => tf.scalar(Math.random()).variable());

  // Fit the random parameters with plain SGD on the mean squared error.
  const loss = (pred, label) => pred.sub(label).square().mean();
  const learningRate = 0.01;
  // The quadratic basis on x in [0, 3] is ill-conditioned, so at this learning
  // rate 10 steps leave the fit visibly unconverged; a few hundred steps are needed.
  const iterations = 500;
  const optimizer = tf.train.sgd(learningRate);
  for (let i = 0; i < iterations; i++) {
    // Pass the variable list explicitly so only k is trained.
    optimizer.minimize(() => loss(f(xs, k), ys), false, k);
  }

  const fitted = f(xs, k);
  plot(xs, ys, fitted);

  // plot() copies values out with arraySync(), so the tensors can be released now.
  tf.dispose([xs, ys, fitted, ...kHidden, ...k]);
};
| </script> | |
| </head> | |
| <body></body> | |
| </html> |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment