Last active: February 26, 2019 21:22
Save trygvea/978037b0f8fd9abd3e97dc35eb99a262 to your computer and use it in GitHub Desktop.
Machine learning - learn to add
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import * as tf from '@tensorflow/tfjs' | |
/*
 * Ideas to experiment with:
 * - Other target functions (i.e. change unknown_function), such as
 *   multiplication (x1 * x2) or a combination like x1 + 3 * x2.
 *   These may need around 500 training epochs to give satisfactory results.
 */
/**
 * Draws a pseudo-random integer using `Math.random` (not cryptographically secure).
 * For `max >= 1` the result lies in the half-open range [0, floor(max)).
 */
const randomInt = (max: number) => {
  const upperBound = Math.floor(max)
  return Math.floor(upperBound * Math.random())
}
/**
 * Produces `n` independent draws from `randomInt(max)`.
 * `n` must be a valid array length; the Array constructor rejects anything else.
 */
const randomInts = (max: number, n: number): number[] => {
  const slots = new Array<number>(n).fill(0)
  return slots.map(() => randomInt(max))
}
/**
 * The hidden target the network is trained to approximate: here, plain addition.
 * Edit the body to try other targets (see the header comment).
 */
const unknown_function = (x1: number, x2: number): number => {
  const sum = x1 + x2
  return sum
}
/**
 * Builds `n` labelled training samples.
 * Each input is a pair `[x1, x2]` of integers drawn with `randomInts(max, n)`.
 * Each label is `unknown_function(x1, x2)` for that pair.
 */
const generateSamples = (max: number, n: number) => {
  const firsts = randomInts(max, n)
  const seconds = randomInts(max, n)
  const xs = firsts.map((first, i) => [first, seconds[i]])
  const ys = firsts.map((first, i) => unknown_function(first, seconds[i]))
  return { xs, ys }
}
/** One model input: the feature values of a single sample (the pair [x1, x2] in this script). */
type Data = Array<number>
/** The true target value recorded for one sample. */
type Observation = number
/** A value produced by the trained model; numerically the same kind as an observation. */
type Prediction = Observation
/**
 * Trains a small fully connected regression network on the given samples and
 * returns a predictor for new inputs.
 *
 * @param xs - Training inputs, one row per sample; all rows must share the same
 *   length (2 for the samples produced by `generateSamples`).
 * @param ys - Training targets, one per row of `xs`.
 * @param options.epochs - Number of training epochs (default 50). Harder targets
 *   such as multiplication may need around 500.
 * @returns A function mapping one input row to the model's scalar prediction.
 * @throws Error if `xs` is empty or `xs` and `ys` have different lengths.
 */
async function learn(
  xs: Data[],
  ys: Observation[],
  { epochs = 50 }: { epochs?: number } = {},
): Promise<(x: Data) => Prediction> {
  const first = xs[0]
  if (first === undefined || xs.length !== ys.length) {
    throw new Error(
      `learn: expected non-empty xs and ys of equal length (got ${xs.length} and ${ys.length})`,
    )
  }
  // Derive the feature count from the data instead of hard-coding 2.
  const inputDim = first.length

  const model = tf.sequential({
    layers: [
      tf.layers.dense({ units: 16, activation: 'relu6', inputShape: [inputDim] }),
      tf.layers.dense({ units: 16, activation: 'relu6' }),
      tf.layers.dense({ units: 1 }),
    ],
  })
  model.compile({ optimizer: 'adam', loss: 'meanSquaredError' })

  // tfjs tensors must be released explicitly (see the tfjs memory-management
  // guide); free the training tensors whether fitting succeeds or fails.
  const inputs = tf.tensor2d(xs, [xs.length, inputDim])
  const labels = tf.tensor2d(ys, [ys.length, 1])
  try {
    await model.fit(inputs, labels, { epochs })
  } finally {
    inputs.dispose()
    labels.dispose()
  }

  return (x: Data): Prediction =>
    // tidy() disposes the input and output tensors created for this call.
    tf.tidy(() => {
      // The model has a single output layer, so predict yields one Tensor, not Tensor[].
      const output = model.predict(tf.tensor2d(x, [1, inputDim])) as tf.Tensor
      return output.dataSync()[0]
    })
}
// Train on 500 random pairs of integers in [0, 10), then spot-check a few sums.
const { xs, ys } = generateSamples(10, 500)
console.log('learning...')
const probePairs: Array<[number, number]> = [[1, 3], [5, 2], [6, 7], [9, 9]]
void learn(xs, ys)
  .then(predict => {
    console.log('finished learning.')
    for (const [a, b] of probePairs) {
      console.log(`#### ${a}+${b} = `, predict([a, b]))
    }
  })
  // Without a handler, a failure during training would be an unhandled rejection.
  .catch((err: unknown) => {
    console.error('learning failed:', err)
  })
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.