Skip to content

Instantly share code, notes, and snippets.

@Saren-Arterius
Created December 20, 2018 06:58
Show Gist options
  • Save Saren-Arterius/1b3cfff235cd8dc2bef9d8c0a62e222a to your computer and use it in GitHub Desktop.
Save Saren-Arterius/1b3cfff235cd8dc2bef9d8c0a62e222a to your computer and use it in GitHub Desktop.
// Training operands x and y are drawn uniformly from [0, 10 ** trainExp).
const trainExp = 7;
// Number of samples generated per training attempt (training + validation together).
const datasetSize = 2000;
// Run validation and log its accuracy every this many epochs.
const printValidateEvery = 100;
// Number of passes over the training set in one training attempt.
const epoches = 1000;
// Fraction of the generated samples held out for validation.
const validateRatio = 0.2;
// Required validation accuracy, where accuracy = 1 - mean absolute error of the scaled sum.
const accuracyThreshold = 0.99999;
// From this epoch on, a validation below accuracyThreshold discards the attempt and retrains from scratch.
const shouldPassAccuracyThresholdAtLeastEpoch = 300;
// Trained weights; filled in lazily by add() on first use.
let network;
/**
 * Forward pass of the 2-input / 4-hidden / 1-output linear network.
 *
 * Weights are stored flat as
 * [h1r, h2r, h3r, h4r, xh1, xh2, xh3, xh4, yh1, yh2, yh3, yh4]:
 * indices 0-3 are hidden->output weights, 4-7 are x->hidden and 8-11 are
 * y->hidden. There is no bias and no activation function.
 *
 * @param {number[]} network - the 12 flat weights described above
 * @param {number[]} data - a sample [x, y, ...]; only the first two entries are read
 * @returns {number[]} [h1Out, h2Out, h3Out, h4Out, rOut]
 */
function evalNN(network, data) {
  const inX = data[0];
  const inY = data[1];

  // Hidden unit k sees x through network[4 + k] and y through network[8 + k].
  const hidden = [0, 1, 2, 3].map((k) => (inX * network[4 + k]) + (inY * network[8 + k]));

  // Output is the hidden activations weighted by network[0..3], summed left to right.
  const output = hidden
    .map((activation, k) => network[k] * activation)
    .reduce((acc, term) => acc + term);

  return [...hidden, output];
}
/**
 * Scale an addition problem into the network's working range.
 *
 * Chooses exp = ceil(log10(scale)), where scale is the largest magnitude among
 * x, y and x + y. Then [x, y, x + y] are all divided by 10 ** exp, so every
 * entry lies in [-1, 1], up to floating-point rounding. For non-negative
 * operands the scale is exactly x + y, and the scaled sum lands in (0, 1].
 *
 * A zero scale (x = y = 0) uses exp = 0 because log10(0) is -Infinity, which
 * would turn every entry into 0/0 = NaN. Magnitudes are used because log10 of
 * a negative sum is NaN.
 *
 * @param {number} x
 * @param {number} y
 * @returns {[number, number[]]} [exp, [x / 10**exp, y / 10**exp, (x + y) / 10**exp]]
 */
function toRange(x, y) {
  const sum = x + y;
  const scale = Math.max(Math.abs(x), Math.abs(y), Math.abs(sum));
  const exp = scale === 0 ? 0 : Math.ceil(Math.log10(scale));
  const divisor = 10 ** exp;
  const entry = [x / divisor, y / divisor, sum / divisor];
  return [exp, entry];
}
/**
 * Randomly generate one addition dataset, already scaled with toRange.
 *
 * Each sample is toRange(x, y)[1], i.e. [x, y, x + y] divided by a common
 * power of ten. x and y are drawn uniformly from [0, 10 ** trainExp). The
 * first datasetSize * validateRatio samples are held out for validation; the
 * rest are used for training.
 *
 * @returns {{trainDataset: number[][], validateDataset: number[][]}}
 */
function generateDatasets() {
  const trainDataset = [];
  const validateDataset = [];
  for (let i = 0; i < datasetSize; i++) {
    const x = Math.random() * (10 ** trainExp);
    const y = Math.random() * (10 ** trainExp);
    const entry = toRange(x, y)[1];
    if (i < datasetSize * validateRatio) {
      validateDataset.push(entry);
    } else {
      trainDataset.push(entry);
    }
  }
  return { trainDataset, validateDataset };
}

/**
 * Run one training attempt from a fresh dataset and fresh random weights.
 *
 * Every printValidateEvery epochs, the mean absolute error on the validation
 * set is measured and logged as accuracy = 1 - mean error. If that accuracy is
 * below accuracyThreshold at a checkpoint on or after epoch
 * shouldPassAccuracyThresholdAtLeastEpoch, the attempt is abandoned.
 *
 * @returns {number[] | null} the 12 weights in evalNN's order, or null if the attempt was abandoned
 */
function trainAttempt() {
  const { trainDataset, validateDataset } = generateDatasets();

  // init weights uniformly in [0, 1)
  let [h1r, h2r, h3r, h4r, xh1, xh2, xh3, xh4, yh1, yh2, yh3, yh4] = [
    Math.random(), Math.random(), Math.random(), Math.random(),
    Math.random(), Math.random(), Math.random(), Math.random(),
    Math.random(), Math.random(), Math.random(), Math.random(),
  ];

  // Any weight that an update pushes outside [-10, 10] is re-drawn from [0, 1)
  // rather than being allowed to keep growing.
  const reinitIfNeeded = (v) => (v < -10 || v > 10 ? Math.random() : v);

  for (let i = 1; i <= epoches; i++) {
    for (const data of trainDataset) {
      const [h1Out, h2Out, h3Out, h4Out, rOut] = evalNN([h1r, h2r, h3r, h4r, xh1, xh2, xh3, xh4, yh1, yh2, yh3, yh4], data);
      const eR = data[2] - rOut;
      // Output error pushed back to each hidden unit through its output weight.
      // It is computed before any weight is overwritten below.
      const [eH1, eH2, eH3, eH4] = [eR * h1r, eR * h2r, eR * h3r, eR * h4r];
      // NOTE(review): these are not the textbook squared-error gradients of evalNN.
      // The textbook rule would move hkr by eR * hkOut, xhk by eHk * x and yhk by
      // eHk * y. Here every hkr moves by eR * rOut, and xhk and yhk both move by
      // eHk * hkOut, with an implicit learning rate of 1. Training relies on the
      // clamp above and on restarts. Changing the rule alters convergence, so
      // re-tune before "fixing" it.
      [h1r, h2r, h3r, h4r, xh1, xh2, xh3, xh4, yh1, yh2, yh3, yh4] = [
        reinitIfNeeded((eR * rOut) + h1r), reinitIfNeeded((eR * rOut) + h2r), reinitIfNeeded((eR * rOut) + h3r), reinitIfNeeded((eR * rOut) + h4r),
        reinitIfNeeded((eH1 * h1Out) + xh1), reinitIfNeeded((eH2 * h2Out) + xh2), reinitIfNeeded((eH3 * h3Out) + xh3), reinitIfNeeded((eH4 * h4Out) + xh4),
        reinitIfNeeded((eH1 * h1Out) + yh1), reinitIfNeeded((eH2 * h2Out) + yh2), reinitIfNeeded((eH3 * h3Out) + yh3), reinitIfNeeded((eH4 * h4Out) + yh4),
      ];
    }

    // Validate only at checkpoints. Skip when validateRatio leaves no held-out
    // samples, because the mean error of an empty set is undefined.
    if (i % printValidateEvery !== 0 || validateDataset.length === 0) {
      continue;
    }
    let errorSum = 0;
    for (const data of validateDataset) {
      const rOut = evalNN([h1r, h2r, h3r, h4r, xh1, xh2, xh3, xh4, yh1, yh2, yh3, yh4], data)[4];
      errorSum += Math.abs(data[2] - rOut);
    }
    // Every term is an absolute value, so the mean error is already non-negative.
    const accuracy = 1 - errorSum / validateDataset.length;
    console.log(`Epoch ${i}, Accuracy: ${accuracy}`);
    if (i >= shouldPassAccuracyThresholdAtLeastEpoch && accuracy < accuracyThreshold) {
      return null; // not accurate enough by the deadline: abandon this attempt
    }
  }
  return [h1r, h2r, h3r, h4r, xh1, xh2, xh3, xh4, yh1, yh2, yh3, yh4];
}

/**
 * Train the addition network with "Las Vegas" restarts.
 *
 * Runs independent training attempts, each with new data and new weights,
 * until one finishes all epochs without failing its accuracy deadline.
 * Progress is logged by each attempt.
 *
 * @returns {number[]} the 12 trained weights
 *   [h1r, h2r, h3r, h4r, xh1, xh2, xh3, xh4, yh1, yh2, yh3, yh4]
 */
function trainNN() {
  // Retry in a loop rather than by recursion, so a long run of failed attempts
  // cannot exhaust the call stack.
  for (;;) {
    const weights = trainAttempt();
    if (weights !== null) {
      return weights;
    }
  }
}
/**
 * Add two numbers using the trained network.
 *
 * The first call trains the network with trainNN, which logs progress, and
 * caches the weights in the module-level `network`; later calls reuse them.
 * The operands are scaled by a power of ten with toRange and fed through
 * evalNN. The predicted, still-scaled sum is then multiplied back up and
 * rounded to the nearest integer.
 *
 * @param {number} x
 * @param {number} y
 * @returns {number} the network's estimate of x + y, rounded to an integer
 */
function add(x, y) {
  network = network || trainNN();

  const [exp, scaled] = toRange(x, y);
  const prediction = evalNN(network, scaled)[4];
  const magnitude = 10 ** exp;
  return Math.round(prediction * magnitude);
}
// Demo: print a few sums, then the learned weights. The first add() call
// triggers training.
for (const [a, b] of [[1, 2], [10, 10], [111, 222], [999, 999], [123456789, 987654321]]) {
  console.log(add(a, b));
}
console.log(network);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment