Skip to content

Instantly share code, notes, and snippets.

@esshka
Created March 8, 2023 10:46
Show Gist options
  • Save esshka/4c960a0171822de23506f42c3d28073c to your computer and use it in GitHub Desktop.
Save esshka/4c960a0171822de23506f42c3d28073c to your computer and use it in GitHub Desktop.
// Gradient-descent step size shared by every update in this file: the weight
// step in Node.weightedErrorSum and the bias step in Node.backpropagate.
const learningRate = 0.3;
/** Logistic activation function and its derivative. */
class Sigmoid {
  /**
   * Logistic sigmoid: 1 / (1 + e^(-x)).
   * @param {number} x - pre-activation value
   * @returns {number} activation in [0, 1] (saturates to exactly 0 or 1 for large |x|)
   */
  static calculate(x) {
    const denominator = 1 + Math.exp(-x);
    return 1 / denominator;
  }

  /**
   * Derivative of the sigmoid, written in terms of the sigmoid itself:
   * s(x) * (1 - s(x)).
   * @param {number} x - pre-activation value
   * @returns {number}
   */
  static derivative(x) {
    const s = Sigmoid.calculate(x);
    const complement = 1 - s;
    return s * complement;
  }
}
/** A single neuron: holds its bias, activation, and connections to neighbours. */
class Node {
  // Connections ending at this node; their sources feed weightedInputSum().
  incoming = [];
  // Connections starting at this node (i.e. outgoing; name kept for compatibility).
  outcoming = [];
  bias = 0;
  // Activation from the latest forward pass.
  output;
  // d(output)/d(pre-activation) from the latest forward pass; 1 when the
  // output was set directly (input nodes).
  outputDerivative;
  // Delta from the latest backward pass: incoming error signal × outputDerivative.
  error;

  /**
   * Pre-activation of this node: bias plus the weighted outputs of all
   * upstream nodes.
   * @returns {number}
   */
  weightedInputSum = () =>
    this.incoming.reduce(
      (acc, conn) => acc + conn.weight * conn.from.output,
      this.bias,
    );

  /**
   * Error signal arriving from downstream nodes: Σ (weight × downstream delta).
   *
   * Side effect: also takes one gradient-descent step on every outgoing weight
   * (w -= learningRate × downstream delta × this.output). Each term of the sum
   * is read BEFORE its weight is updated, so the propagated error uses the
   * same weight the forward pass used.
   * @returns {number}
   */
  weightedErrorSum = () => {
    let sum = 0;
    for (const conn of this.outcoming) {
      const downstreamDelta = conn.to.error;
      sum += conn.weight * downstreamDelta;
      conn.setWeight(conn.weight - learningRate * downstreamDelta * this.output);
    }
    return sum;
  };

  /**
   * Compute this node's activation.
   * @param {number} [x] - when not null/undefined, used verbatim as the output
   *   (input nodes, derivative 1); otherwise the output is the sigmoid of
   *   weightedInputSum().
   */
  forward = (x) => {
    if (x != null) {
      this.output = x;
      this.outputDerivative = 1;
      return;
    }
    const sum = this.weightedInputSum();
    this.output = Sigmoid.calculate(sum);
    this.outputDerivative = Sigmoid.derivative(sum);
  };

  /**
   * Compute this node's delta and take a gradient step on its bias.
   * @param {number} [desiredOutput] - target value for an output node. When
   *   omitted (null/undefined) the error signal is taken from downstream via
   *   weightedErrorSum(), which also updates this node's outgoing weights.
   */
  backpropagate = (desiredOutput) => {
    const signal =
      desiredOutput != null
        ? this.output - desiredOutput
        : this.weightedErrorSum();
    this.error = signal * this.outputDerivative;
    this.bias -= this.error * learningRate;
  };

  /**
   * Create a Connection from this node to `to` and register it on both ends.
   * @param {Node} to - downstream node
   */
  connectTo(to) {
    const conn = new Connection(this, to);
    to.incoming.push(conn);
    this.outcoming.push(conn);
  }
}
class Connection {
  /**
   * Directed, weighted edge between two nodes.
   * The initial weight is drawn uniformly from [-1, 1).
   * @param {Node} from - upstream node whose output is read
   * @param {Node} to - downstream node that receives the weighted signal
   */
  constructor(from, to) {
    Object.assign(this, { from, to, weight: 2 * Math.random() - 1 });
  }

  /**
   * Replace the weight. Declared as an arrow-function field so it stays bound
   * to this connection even when called detached.
   * @param {number} value - the new weight
   */
  setWeight = (value) => {
    this.weight = value;
  };
}
/** A group of nodes at the same depth of the network. */
class Layer {
  /**
   * @param {number} numNodes - how many nodes to create
   */
  constructor(numNodes) {
    this.numNodes = numNodes;
    this.buildNodes();
  }

  /** Replace `this.nodes` with `numNodes` fresh Node instances. */
  buildNodes = () => {
    this.nodes = Array.from(new Array(this.numNodes), () => new Node());
  };

  /**
   * Run the forward pass of every node.
   * @param {number[]} [inputs] - when given, node i receives inputs[i] as its
   *   output directly; otherwise each node receives null and computes its own.
   */
  forward = (inputs) => {
    for (const [index, node] of this.nodes.entries()) {
      const value = inputs ? inputs[index] : null;
      node.forward(value);
    }
  };

  /**
   * Run the backward pass of every node.
   * @param {number[]} [desiredOutputs] - per-node targets (output layer only);
   *   otherwise each node receives null.
   */
  backpropagate(desiredOutputs) {
    for (const [index, node] of this.nodes.entries()) {
      const target = desiredOutputs ? desiredOutputs[index] : null;
      node.backpropagate(target);
    }
  }

  /**
   * Fully connect this layer to `layer`: every node here gets an edge to
   * every node there.
   * @param {Layer} layer - the next layer
   */
  connectTo = (layer) => {
    for (const source of this.nodes) {
      for (const target of layer.nodes) {
        source.connectTo(target);
      }
    }
  };
}
/**
 * Fully connected feed-forward network: an input layer, one or more hidden
 * layers, and an output layer, each layer connected all-to-all to the next.
 */
class Network {
  /**
   * @param {number} numInput - number of input nodes
   * @param {number|number[]} numHidden - size of the single hidden layer, or an
   *   array of sizes (input side first) to stack several hidden layers
   * @param {number} numOutput - number of output nodes
   */
  constructor(numInput, numHidden, numOutput) {
    const hiddenSizes = Array.isArray(numHidden) ? numHidden : [numHidden];
    this.inputLayer = new Layer(numInput);
    this.hiddenLayers = hiddenSizes.map((size) => new Layer(size));
    this.outputLayer = new Layer(numOutput);
    this.layers = [this.inputLayer, ...this.hiddenLayers, this.outputLayer];
    this.connectLayers();
  }

  /** Connect every layer to the layer that follows it. */
  connectLayers = () => {
    this.layers.forEach((layer, i) => {
      const nextLayer = this.layers[i + 1];
      if (nextLayer) layer.connectTo(nextLayer);
    });
  };

  /**
   * Run a forward pass.
   * @param {number[]} inputs - one value per input node
   * @returns {number[]|undefined} output-layer activations, or undefined
   *   (without running the pass) when inputs.length does not match the
   *   number of input nodes
   */
  forward = (inputs) => {
    if (inputs.length !== this.inputLayer.nodes.length) return undefined;
    this.inputLayer.forward(inputs);
    this.hiddenLayers.forEach((layer) => layer.forward());
    this.outputLayer.forward();
    return this.outputLayer.nodes.map((n) => n.output);
  };

  /**
   * Run a backward pass against the targets of the most recent forward pass.
   * @param {number[]} desiredOutputs - one target per output node
   */
  backpropagate = (desiredOutputs) => {
    this.outputLayer.backpropagate(desiredOutputs);
    // Each node reads the `error` of the layer after it, so hidden layers
    // must be visited last-to-first.
    [...this.hiddenLayers].reverse().forEach((layer) => layer.backpropagate());
    // The input layer is visited too: its nodes own and update the weights of
    // the input->first-hidden connections.
    this.inputLayer.backpropagate();
  };

  /**
   * Online training: one forward + backward pass per sample, repeated.
   * @param {{input: number[], output: number[]}[]} dataset - training samples
   * @param {number} numIterations - number of passes over the dataset
   */
  train = (dataset, numIterations) => {
    for (let epoch = 0; epoch < numIterations; epoch++) {
      for (let j = 0; j < dataset.length; j++) {
        const { input, output } = dataset[j];
        // A wrongly sized input is rejected by forward(); back-propagating
        // anyway would pair the previous sample's activations with this target.
        if (this.forward(input) === undefined) continue;
        this.backpropagate(output);
      }
    }
  };

  /** @returns {number[][]} latest activations, per layer, per node */
  getValues = () => this.layers.map((l) => l.nodes.map((n) => n.output));

  /** @returns {number[][][]} outgoing weights, per layer, per node */
  getWeights = () =>
    this.layers.map((l) => l.nodes.map((n) => n.outcoming.map((c) => c.weight)));
}
// Demo: learn a 2-input mapping whose target is simply the first input,
// with 2 input nodes, one hidden layer of 2 nodes, and 1 output node.
const nn = new Network(2, 2, 1);
const dataset = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [0] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [1] },
];
nn.train(dataset, 10000);
// Each prediction should end up close to its sample's target:
// [0, 0] -> ~0, [0, 1] -> ~0, [1, 0] -> ~1, [1, 1] -> ~1.
for (const { input } of dataset) {
  console.log(nn.forward(input));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment