Created March 8, 2023 10:46
-
-
Save esshka/4c960a0171822de23506f42c3d28073c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Gradient-descent step size shared by every weight update and every bias
 * update performed during backpropagation.
 */
const learningRate = 0.3;
/**
 * Logistic sigmoid activation function, exposed as static helpers.
 */
class Sigmoid {
  /**
   * Evaluates the logistic function 1 / (1 + e^(-x)).
   * @param {number} x - pre-activation value.
   * @returns {number} activation; in floating point it saturates to 0 or 1
   *   for inputs of large magnitude.
   */
  static calculate(x) {
    const denominator = 1 + Math.exp(-x);
    return 1 / denominator;
  }

  /**
   * Evaluates the slope of the logistic function at x, using the identity
   * s'(x) = s(x) * (1 - s(x)).
   * @param {number} x - pre-activation value.
   * @returns {number} derivative of the activation at x.
   */
  static derivative(x) {
    const s = Sigmoid.calculate(x);
    const complement = 1 - s;
    return s * complement;
  }
}
/**
 * A single neuron.
 *
 * A node fed an explicit value (input layer) simply holds that value. Every
 * other node outputs sigmoid(bias + sum of weight * upstream output) over its
 * incoming connections. `backpropagate` computes the node's error signal,
 * steps its bias, and (for non-output nodes) steps the weights of its
 * outgoing connections.
 */
class Node {
  /** Connections whose `to` end is this node. */
  incoming = [];
  /** Connections whose `from` end is this node (field name kept as-is for callers). */
  outcoming = [];
  /** Additive bias, adjusted by `backpropagate`. */
  bias = 0;
  /** Output produced by the most recent `forward` call. */
  output;
  /** Activation slope recorded by the most recent `forward` call (1 for fed-in values). */
  outputDerivative;
  /**
   * Error signal from the most recent `backpropagate` call:
   * (output - target) * outputDerivative for a node given a target, otherwise
   * (sum of weight * downstream error) * outputDerivative.
   */
  error;

  /**
   * @returns {number} the bias plus the weighted sum of upstream outputs.
   */
  weightedInputSum = () =>
    this.incoming.reduce(
      (acc, conn) => acc + conn.weight * conn.from.output,
      this.bias,
    );

  /**
   * Collects the error flowing back through each outgoing connection and
   * applies the gradient-descent update to that connection's weight.
   *
   * Each connection's contribution is taken with the weight as it was during
   * the forward pass, *before* the update. Using the already-updated weight
   * (as the previous version did) propagates an error that does not
   * correspond to the network that produced `output`.
   * Requires every downstream node to have its `error` set already.
   * @returns {number} sum over outgoing connections of weight * downstream error.
   */
  weightedErrorSum = () => {
    let sum = 0;
    for (const conn of this.outcoming) {
      const downstreamError = conn.to.error;
      sum += conn.weight * downstreamError;
      conn.setWeight(
        conn.weight - learningRate * downstreamError * this.output,
      );
    }
    return sum;
  };

  /**
   * Computes this node's output.
   * @param {number} [x] - when provided (not null/undefined) the node acts as
   *   an input node: `output` becomes `x` and `outputDerivative` becomes 1.
   *   Otherwise `output` is the sigmoid of `weightedInputSum()`.
   */
  forward = (x) => {
    if (x !== undefined && x !== null) {
      this.output = x;
      this.outputDerivative = 1;
      return;
    }
    const sum = this.weightedInputSum();
    this.output = Sigmoid.calculate(sum);
    this.outputDerivative = Sigmoid.derivative(sum);
  };

  /**
   * Computes this node's error signal and takes a gradient step on its bias.
   *
   * With a target, the raw error is `output - desiredOutput`. Without one,
   * it is gathered from the outgoing connections via `weightedErrorSum`,
   * which also updates those connections' weights.
   * @param {number} [desiredOutput] - target value for an output node.
   */
  backpropagate = (desiredOutput) => {
    const hasDesiredOutput =
      desiredOutput !== undefined && desiredOutput !== null;
    const sum = hasDesiredOutput
      ? this.output - desiredOutput
      : this.weightedErrorSum();
    this.error = sum * this.outputDerivative;
    this.bias -= this.error * learningRate;
  };

  /**
   * Creates a Connection from this node to `to` and registers it on both ends.
   * @param {Node} to - downstream node.
   */
  connectTo(to) {
    const conn = new Connection(this, to);
    to.incoming.push(conn);
    this.outcoming.push(conn);
  }
}
/**
 * Directed, weighted edge between two nodes.
 */
class Connection {
  /**
   * @param {object} from - upstream node.
   * @param {object} to - downstream node.
   * The weight starts uniformly random in [-1, 1).
   */
  constructor(from, to) {
    Object.assign(this, { from, to, weight: Math.random() * 2 - 1 });
  }

  /**
   * Overwrites the connection weight.
   * @param {number} x - new weight.
   */
  setWeight = (x) => {
    this.weight = x;
  };
}
/**
 * A group of nodes that is evaluated, trained, and wired up as a unit.
 */
class Layer {
  /**
   * @param {number} numNodes - how many nodes the layer holds.
   */
  constructor(numNodes) {
    this.numNodes = numNodes;
    this.buildNodes();
  }

  /** Replaces `this.nodes` with `numNodes` freshly constructed Node instances. */
  buildNodes = () => {
    this.nodes = Array.from(new Array(this.numNodes), () => new Node());
  };

  /**
   * Calls `forward` on every node, in order. When `inputs` is truthy, node i
   * receives `inputs[i]`. Otherwise every node receives `null`.
   * @param {number[]} [inputs]
   */
  forward = (inputs) => {
    for (const [index, node] of this.nodes.entries()) {
      const value = inputs ? inputs[index] : null;
      node.forward(value);
    }
  };

  /**
   * Calls `backpropagate` on every node, in order. When `desiredOutputs` is
   * truthy, node i receives `desiredOutputs[i]`. Otherwise every node
   * receives `null`.
   * @param {number[]} [desiredOutputs]
   */
  backpropagate(desiredOutputs) {
    for (let index = 0; index < this.nodes.length; index++) {
      const target = desiredOutputs ? desiredOutputs[index] : null;
      this.nodes[index].backpropagate(target);
    }
  }

  /**
   * Fully connects this layer to `layer`: every node here gets a connection
   * to every node there.
   * @param {Layer} layer - downstream layer.
   */
  connectTo = (layer) => {
    for (const source of this.nodes) {
      for (const target of layer.nodes) {
        source.connectTo(target);
      }
    }
  };
}
/**
 * Fully connected feed-forward network: an input layer, one or more hidden
 * layers, and an output layer, trained one sample at a time by
 * backpropagation.
 */
class Network {
  /**
   * @param {number} numInput - number of input nodes.
   * @param {number|number[]} numHidden - size of the single hidden layer, or
   *   an array of sizes (input side first) for several hidden layers.
   * @param {number} numOutput - number of output nodes.
   */
  constructor(numInput, numHidden, numOutput) {
    const hiddenSizes = Array.isArray(numHidden) ? numHidden : [numHidden];
    this.inputLayer = new Layer(numInput);
    this.hiddenLayers = hiddenSizes.map((size) => new Layer(size));
    this.outputLayer = new Layer(numOutput);
    this.layers = [this.inputLayer, ...this.hiddenLayers, this.outputLayer];
    this.connectLayers();
  }

  /** Fully connects each layer to the one after it. */
  connectLayers = () => {
    for (let i = 0; i + 1 < this.layers.length; i++) {
      this.layers[i].connectTo(this.layers[i + 1]);
    }
  };

  /**
   * Runs one forward pass.
   * @param {number[]} inputs - one value per input node.
   * @returns {number[]} the output layer's node outputs.
   * @throws {RangeError} if `inputs` does not have one value per input node.
   *   (Previously this returned `undefined` silently, and `train` would then
   *   backpropagate the previous sample's activations.)
   */
  forward = (inputs) => {
    const expected = this.inputLayer.nodes.length;
    if (inputs.length !== expected) {
      throw new RangeError(`Expected ${expected} inputs, got ${inputs.length}`);
    }
    this.inputLayer.forward(inputs);
    this.hiddenLayers.forEach((l) => l.forward());
    this.outputLayer.forward();
    return this.outputLayer.nodes.map((n) => n.output);
  };

  /**
   * Runs one backward pass against the targets of the last `forward` call.
   * @param {number[]} desiredOutputs - one target per output node.
   */
  backpropagate = (desiredOutputs) => {
    this.outputLayer.backpropagate(desiredOutputs);
    // A node's error is built from the errors of the layer after it, so hidden
    // layers must be visited from the output side back to the input side.
    for (let i = this.hiddenLayers.length - 1; i >= 0; i--) {
      this.hiddenLayers[i].backpropagate();
    }
    // The input layer is visited too: its nodes update the weights of their
    // outgoing (input -> first hidden) connections while collecting error.
    this.inputLayer.backpropagate();
  };

  /**
   * Online training: a forward then a backward pass for every sample, repeated
   * `numIterations` times.
   * @param {{input: number[], output: number[]}[]} dataset
   * @param {number} numIterations - number of passes over the dataset.
   */
  train = (dataset, numIterations) => {
    for (let epoch = 0; epoch < numIterations; epoch++) {
      for (const { input, output } of dataset) {
        this.forward(input);
        this.backpropagate(output);
      }
    }
  };

  /** @returns {number[][]} the current output of every node, grouped by layer. */
  getValues = () => this.layers.map((l) => l.nodes.map((n) => n.output));

  /**
   * @returns {number[][][]} for each layer, for each node, the weights of that
   *   node's outgoing connections.
   */
  getWeights = () =>
    this.layers.map((l) => l.nodes.map((n) => n.outcoming.map((c) => c.weight)));
}
// Demo: train a 2-2-1 network on a dataset whose target equals the first
// input bit (the second input does not affect the target), then print the
// network's prediction for each sample.
const nn = new Network(2, 2, 1);
const dataset = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [0] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [1] },
];
nn.train(dataset, 10000);
// After training, each prediction should be close to its dataset target.
console.log(nn.forward([0, 0])); // Output should be close to 0
console.log(nn.forward([0, 1])); // Output should be close to 0
console.log(nn.forward([1, 0])); // Output should be close to 1
console.log(nn.forward([1, 1])); // Output should be close to 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.