// NEAT — GitHub gist esshka/ff3374f755fea014b51b3d9757f41435
// Created September 7, 2023 17:53
//
// Save esshka/ff3374f755fea014b51b3d9757f41435 to your computer and use it in GitHub Desktop.
//
// Notice from the gist page: this file contains hidden or bidirectional Unicode
// text that may be interpreted or compiled differently than what appears below.
// To review, open the file in an editor that reveals hidden Unicode characters.
// Learn more about bidirectional Unicode characters.
/**
 * A single neuron of a genome's network.
 *
 * `value` holds the node's current output: for nodes of type 'input' it is
 * assigned from outside, for every other type it is produced by `compute`.
 */
class Node {
  id;
  value = 0.0;
  type;

  /**
   * @param {number} id - Identifier of the node.
   * @param {string} type - Node kind; 'input' is the only type `compute` treats specially.
   */
  constructor(id, type) {
    this.id = id;
    this.type = type;
  }

  /**
   * Logistic activation function.
   *
   * @param {number} x
   * @returns {number} 1 / (1 + e^-x)
   */
  sigmoid(x) {
    return 1 / (1 + Math.exp(-x));
  }

  /**
   * Updates and returns this node's output.
   *
   * Input nodes simply return their externally assigned value. Any other node
   * takes the weighted sum of its incoming connections and squashes it with
   * the sigmoid.
   *
   * @param {Array<{weight: number, from: {value: number}}>} incoming -
   *   Connections ending at this node.
   * @returns {number} The node's output.
   */
  compute(incoming) {
    if (this.type === 'input') return this.value;

    const weightedSum = incoming.reduce(
      (total, connection) => total + connection.weight * connection.from.value,
      0,
    );

    // Store the ACTIVATED value: downstream nodes and callers read `value`,
    // so keeping the raw sum here would silently skip the activation.
    this.value = this.sigmoid(weightedSum);
    return this.value;
  }
}
/**
 * A weighted, directed link between two nodes.
 *
 * Disabled connections stay in the genome but are meant to be skipped when
 * the network is evaluated.
 */
class Connection {
  from;
  to;
  weight;
  enabled = true;

  /**
   * @param {object} from - Source node.
   * @param {object} to - Target node.
   * @param {number} weight - Multiplier applied to the source node's value.
   * @param {boolean} [enabled=true] - Initial enabled state; lets a connection
   *   be copied together with its state.
   */
  constructor(from, to, weight, enabled = true) {
    this.from = from;
    this.to = to;
    this.weight = weight;
    this.enabled = enabled;
  }
}
/**
 * A network genome: a list of nodes plus the connections between them.
 */
class Genome {
  nodes = [];
  connections = [];

  /**
   * Runs one pass through the network.
   *
   * Nodes are evaluated in dependency order (every node after the sources of
   * its enabled incoming connections), not in insertion order: hidden nodes
   * are appended after the output nodes, so a plain front-to-back walk would
   * make outputs read hidden values left over from the previous call.
   *
   * @param {number[]} inputs - One value per input node, in node order.
   * @returns {number[]} The value of every output node, in node order.
   */
  forward(inputs) {
    this.inputNodes.forEach((node, i) => {
      node.value = inputs[i];
    });

    // Start every non-input node from 0 so the result depends only on
    // `inputs` and never on a previous call (relevant if connections form a
    // cycle: the back edge then reads 0).
    for (const node of this.nodes) {
      if (node.type !== 'input') node.value = 0;
    }

    // Group the enabled connections by target once instead of rescanning the
    // whole connection list for every node.
    const incomingByNode = new Map();
    for (const connection of this.connections) {
      if (!connection.enabled) continue;
      const bucket = incomingByNode.get(connection.to);
      if (bucket) {
        bucket.push(connection);
      } else {
        incomingByNode.set(connection.to, [connection]);
      }
    }

    // Depth-first evaluation. A node is marked before its sources are
    // visited, so a cycle cannot recurse forever.
    const visited = new Set();
    const evaluate = (node) => {
      if (visited.has(node)) return;
      visited.add(node);
      const incoming = incomingByNode.get(node) ?? [];
      if (node.type !== 'input') {
        incoming.forEach((connection) => evaluate(connection.from));
      }
      node.compute(incoming);
    };
    this.nodes.forEach(evaluate);

    return this.outputNodes.map((node) => node.value);
  }

  /** Nodes of type 'input', in insertion order. */
  get inputNodes() {
    return this.nodes.filter((node) => node.type === 'input');
  }

  /** Nodes of type 'output', in insertion order. */
  get outputNodes() {
    return this.nodes.filter((node) => node.type === 'output');
  }

  /**
   * Appends a new node whose id is its index in `nodes`.
   *
   * Kept as an arrow-function field so it stays bound to the instance.
   *
   * @param {string} nodeType
   * @returns {Node} The node that was added.
   */
  addNode = (nodeType) => {
    const nodeId = this.nodes.length;
    const newNode = new Node(nodeId, nodeType);
    this.nodes.push(newNode);
    return newNode;
  };

  /**
   * Appends a new (enabled) connection.
   *
   * @param {Node} from
   * @param {Node} to
   * @param {number} weight
   * @returns {Connection} The connection that was added.
   */
  addConnection(from, to, weight) {
    const newConnection = new Connection(from, to, weight);
    this.connections.push(newConnection);
    return newConnection;
  }

  /** @returns {Node|undefined} A random node, or `undefined` when there are none. */
  getRandomNode() {
    return this.nodes[Math.floor(Math.random() * this.nodes.length)];
  }

  /** @returns {Connection|undefined} A random connection, or `undefined` when there are none. */
  getRandomConnection() {
    return this.connections[Math.floor(Math.random() * this.connections.length)];
  }
}
/**
 * Returns a pseudo-random number scaled from `Math.random()` into the range
 * starting at `min` and spanning `max - min`.
 *
 * @param {number} min - Lower bound (returned when `Math.random()` yields 0).
 * @param {number} max - Upper bound of the range.
 * @returns {number}
 */
function getRandomBetween(min, max) {
  const span = max - min;
  return Math.random() * span + min;
}
/**
 * Builds a minimal starting genome: input nodes, then output nodes, with one
 * randomly weighted connection (weight in the -1..1 range) from every input
 * to every output.
 *
 * Called without arguments it produces the 2-input / 1-output topology.
 *
 * @param {number} [inputCount=2] - Number of input nodes.
 * @param {number} [outputCount=1] - Number of output nodes.
 * @returns {Genome}
 */
function createInitialGenome(inputCount = 2, outputCount = 1) {
  const genome = new Genome();

  // Inputs are added before outputs so node ids keep the inputs-first layout.
  const inputs = Array.from({ length: inputCount }, () => genome.addNode('input'));
  const outputs = Array.from({ length: outputCount }, () => genome.addNode('output'));

  for (const output of outputs) {
    for (const input of inputs) {
      genome.addConnection(input, output, getRandomBetween(-1, 1));
    }
  }

  return genome;
}
// Smoke test: build a basic genome (2 input nodes, 1 output node and one
// randomly weighted connection from each input to the output) and run a
// single forward pass through it.
const genome = createInitialGenome();
const testOutput = genome.forward([1, 0]);
console.log('Test output: ', testOutput);
/**
 * A small NEAT-style evolutionary loop: a population of genomes is scored on
 * XOR, parents are drawn by fitness-proportionate selection and every
 * offspring is a mutated copy of a single parent (no crossover).
 */
class NEAT {
  populationSize = 1;
  population = [];
  generation = 0;
  // Per-offspring probabilities of each mutation kind.
  mutateWeightChance = 0.8;
  newConnectionChance = 0.05;
  newNodeChance = 0.03;

  /**
   * @param {object} [options]
   * @param {number} [options.populationSize=1] - Number of genomes per generation.
   */
  constructor({ populationSize = 1 } = {}) {
    this.populationSize = populationSize;
    this.population = Array.from({ length: populationSize }, () => createInitialGenome());
  }

  /**
   * Mutates `genome` in place. Each of the three mutation kinds is applied
   * independently with its own probability.
   *
   * @param {Genome} genome
   */
  mutate(genome) {
    if (Math.random() < this.mutateWeightChance) {
      const connection = genome.getRandomConnection();
      // A genome without connections has nothing to perturb.
      if (connection) {
        connection.weight += getRandomBetween(-0.5, 0.5);
      }
    }

    if (Math.random() < this.newConnectionChance) {
      this.#addRandomConnection(genome);
    }

    if (Math.random() < this.newNodeChance) {
      const connection = genome.getRandomConnection();
      // Only split a connection that is currently active; splitting an
      // already-disabled one would bring a switched-off path back to life.
      if (connection?.enabled) {
        connection.enabled = false;
        const middleNode = genome.addNode('hidden');
        // Weight 1 in, old weight out, so the new path starts with the same
        // overall weight as the connection it replaces.
        genome.addConnection(connection.from, middleNode, 1);
        genome.addConnection(middleNode, connection.to, connection.weight);
      }
    }
  }

  /**
   * Tries to connect two randomly chosen nodes of different types.
   * The pair is oriented so that no connection ends at an input node or
   * starts at an output node, and an already existing link is not duplicated.
   *
   * @param {Genome} genome
   */
  #addRandomConnection(genome) {
    let from = genome.getRandomNode();
    let to = genome.getRandomNode();
    if (!from || !to || from.type === to.type) return;

    // Input nodes ignore incoming connections and connections leaving an
    // output node can only create loops, so flip such pairs around.
    if (from.type === 'output' || to.type === 'input') {
      [from, to] = [to, from];
    }

    const alreadyLinked = genome.connections.some(
      (connection) => connection.from === from && connection.to === to,
    );
    if (alreadyLinked) return;

    genome.addConnection(from, to, getRandomBetween(-1, 1));
  }

  /**
   * Scores a genome on the four XOR cases.
   *
   * @param {{forward: function(number[]): number[]}} genome
   * @returns {number} 1 / (1 + total squared error); 1 means a perfect fit.
   */
  computeFitness(genome) {
    // XOR cases
    const cases = [
      { inputs: [0, 0], expected: [0] },
      { inputs: [0, 1], expected: [1] },
      { inputs: [1, 0], expected: [1] },
      { inputs: [1, 1], expected: [0] },
    ];

    const totalError = cases.reduce((acc, { inputs, expected }) => {
      const output = genome.forward(inputs);
      const caseError = output.reduce(
        (error, value, index) => error + Math.pow(value - expected[index], 2),
        0,
      );
      return acc + caseError;
    }, 0.0);

    // Fitness is inverse of error
    return 1 / (1 + totalError);
  }

  /**
   * Replaces the population with `populationSize` offspring of parents drawn
   * by roulette-wheel selection, then advances the generation counter.
   */
  evolve() {
    console.log('Evolving...');
    const fitnesses = this.population.map((genome) => this.computeFitness(genome));
    const totalFitness = fitnesses.reduce((acc, fitness) => acc + fitness, 0);
    console.log(`Total fitness: ${totalFitness}`);

    const newPopulation = [];
    // With no parents there is nothing to select from; looping would never end.
    if (this.population.length > 0) {
      while (newPopulation.length < this.populationSize) {
        const parent = this.#selectParent(fitnesses, totalFitness);
        newPopulation.push(this.reproduce(parent));
      }
    }

    this.population = newPopulation;
    this.generation++;
  }

  /**
   * Roulette-wheel selection: a genome's chance of being picked is
   * proportional to its fitness.
   *
   * @param {number[]} fitnesses - Fitness of each genome in `population`.
   * @param {number} totalFitness - Sum of `fitnesses`.
   * @returns {Genome}
   */
  #selectParent(fitnesses, totalFitness) {
    const lastIndex = this.population.length - 1;

    // A zero, NaN or infinite total gives the wheel no usable slices; fall
    // back to a uniform pick instead of spinning forever.
    if (!Number.isFinite(totalFitness) || totalFitness <= 0) {
      return this.population[Math.floor(Math.random() * this.population.length)];
    }

    const pick = Math.random() * totalFitness;
    let current = 0;
    for (let i = 0; i < lastIndex; i++) {
      current += fitnesses[i];
      if (current > pick) return this.population[i];
    }
    // Either the pick landed in the last slice or rounding pushed it past the end.
    return this.population[lastIndex];
  }

  /**
   * Creates a mutated copy of `parent`. Crossover between two parents is not
   * implemented.
   *
   * Relies on every node's `id` being its index in `parent.nodes`.
   *
   * @param {Genome} parent
   * @returns {Genome}
   */
  reproduce(parent) {
    const offspring = new Genome();

    parent.nodes.forEach((node) => {
      offspring.addNode(node.type);
    });

    parent.connections.forEach((connection) => {
      const copy = offspring.addConnection(
        offspring.nodes[connection.from.id],
        offspring.nodes[connection.to.id],
        connection.weight,
      );
      // Carry the enabled flag over: otherwise a connection disabled by a
      // node split would be switched back on in every descendant.
      copy.enabled = connection.enabled;
    });

    this.mutate(offspring);
    return offspring;
  }
}
/**
 * Evolves a fresh population and returns the fittest genome seen.
 *
 * The starting population is scored as well, so a genome is returned even
 * when `maxGenerations` is 0. Evolution stops early once a genome reaches
 * `targetFitness`.
 *
 * @param {number} populationSize - Genomes per generation.
 * @param {number} maxGenerations - Upper bound on the number of `evolve` calls.
 * @param {number} [targetFitness=0.99] - Fitness at which the search stops.
 * @returns {Genome|undefined} The best genome found; `undefined` only if no
 *   genome was ever scored (e.g. an empty population).
 */
function findBest(populationSize, maxGenerations, targetFitness = 0.99) {
  console.log('Finding best genome...');
  const neat = new NEAT({ populationSize });

  let bestGenome;
  let bestFitness = -Infinity;

  // Scores the current population and remembers the fittest genome so far.
  // A plain loop avoids spreading a large array into Math.max.
  const recordBest = () => {
    for (const genome of neat.population) {
      const fitness = neat.computeFitness(genome);
      if (fitness > bestFitness) {
        bestFitness = fitness;
        bestGenome = genome;
      }
    }
  };

  recordBest();
  for (let i = 0; i < maxGenerations && bestFitness < targetFitness; i++) {
    console.log(`Generation ${neat.generation}`);
    neat.evolve();
    recordBest();
  }

  return bestGenome;
}
// Evolve a population of 100 genomes for at most 100 generations, then print
// what the best genome predicts for each XOR case.
const bestGenome = findBest(100, 100);
console.log('Best genome found: ', bestGenome);

const testCases = [
  { inputs: [0, 0], expected: [0] },
  { inputs: [0, 1], expected: [1] },
  { inputs: [1, 0], expected: [1] },
  { inputs: [1, 1], expected: [0] },
];

const predictions = testCases.map(({ inputs }) => ({
  inputs,
  prediction: bestGenome.forward(inputs),
}));

console.log(predictions);
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.