Skip to content

Instantly share code, notes, and snippets.

@esshka
Created September 7, 2023 17:53
Show Gist options
  • Save esshka/ff3374f755fea014b51b3d9757f41435 to your computer and use it in GitHub Desktop.
Save esshka/ff3374f755fea014b51b3d9757f41435 to your computer and use it in GitHub Desktop.
NEAT
/**
 * A single neuron in a genome.
 * `value` holds the node's output: for input nodes, whatever the caller
 * assigned; for every other node, the sigmoid of its weighted input sum.
 */
class Node {
  id;
  value = 0.0;
  type;

  /**
   * @param {number} id - index of the node inside its genome
   * @param {string} type - 'input', 'hidden' or 'output'
   */
  constructor(id, type) {
    this.id = id;
    this.type = type;
  }

  /** Logistic activation, mapping any real number into (0, 1). */
  sigmoid(x) {
    return 1 / (1 + Math.exp(-x));
  }

  /**
   * Recomputes this node's value from its incoming connections.
   * Input nodes keep their assigned value and ignore `incoming`.
   * @param {{weight: number, from: {value: number}}[]} incoming - enabled connections into this node
   * @returns {number} the node's (activated) value
   */
  compute(incoming) {
    if (this.type === 'input') return this.value;
    const sum = incoming.reduce(
      (acc, connection) => acc + connection.weight * connection.from.value,
      0,
    );
    // Store the *activated* value: downstream nodes and the network output
    // read `value`, so keeping the raw sum here made the whole net linear.
    this.value = this.sigmoid(sum);
    return this.value;
  }
}
/**
 * A weighted, directed link between two nodes of a genome.
 * Newly created links are always enabled.
 */
class Connection {
  /**
   * @param {Node} from - source node
   * @param {Node} to - target node
   * @param {number} weight - multiplier applied to the source node's value
   */
  constructor(from, to, weight) {
    Object.assign(this, { from, to, weight, enabled: true });
  }
}
/**
 * A network description: a list of nodes plus the connections between them.
 * Node ids are assigned as their index in `nodes`.
 */
class Genome {
  nodes = [];
  connections = [];

  /**
   * Runs one pass through the network.
   * @param {number[]} inputs - one value per input node, in the order the input nodes were added
   * @returns {number[]} the `value` of every output node after the pass, in node order
   */
  forward(inputs) {
    this.inputNodes.forEach((node, i) => {
      node.value = inputs[i];
    });

    // Group enabled connections by their target once, instead of scanning
    // every connection again for every node.
    const incomingByNode = new Map(this.nodes.map((node) => [node, []]));
    for (const connection of this.connections) {
      if (connection.enabled) {
        incomingByNode.get(connection.to)?.push(connection);
      }
    }

    // Evaluate in dependency order: a node is computed only after every node
    // feeding it. Nodes added later (e.g. hidden nodes from a split) may feed
    // nodes earlier in `nodes`, so plain insertion order would read values
    // left over from the previous call.
    const evaluated = new Set();
    const isReady = (connection) =>
      evaluated.has(connection.from) || !incomingByNode.has(connection.from);
    const pending = [...this.nodes];
    while (pending.length > 0) {
      let next = pending.findIndex(
        (node) => node.type === 'input' || incomingByNode.get(node).every(isReady),
      );
      // Nothing is ready only when the remaining nodes form a cycle; fall back
      // to insertion order so the node reads its sources' current values.
      if (next === -1) next = 0;
      const [node] = pending.splice(next, 1);
      node.compute(incomingByNode.get(node));
      evaluated.add(node);
    }

    return this.outputNodes.map((node) => node.value);
  }

  /** Nodes of type 'input', in insertion order. */
  get inputNodes() {
    return this.nodes.filter((node) => node.type === 'input');
  }

  /** Nodes of type 'output', in insertion order. */
  get outputNodes() {
    return this.nodes.filter((node) => node.type === 'output');
  }

  /**
   * Appends a new node whose id equals its index in `nodes`.
   * Declared as an arrow-function field so it stays bound when detached.
   * @param {string} nodeType - 'input', 'hidden' or 'output'
   * @returns {Node} the new node
   */
  addNode = (nodeType) => {
    const nodeId = this.nodes.length;
    const newNode = new Node(nodeId, nodeType);
    this.nodes.push(newNode);
    return newNode;
  };

  /**
   * Appends a new (enabled) connection.
   * @returns {Connection} the new connection
   */
  addConnection(from, to, weight) {
    const newConnection = new Connection(from, to, weight);
    this.connections.push(newConnection);
    return newConnection;
  }

  /** A uniformly chosen node, or undefined when the genome has none. */
  getRandomNode() {
    return this.nodes[Math.floor(Math.random() * this.nodes.length)];
  }

  /** A uniformly chosen connection (enabled or not), or undefined when there are none. */
  getRandomConnection() {
    return this.connections[Math.floor(Math.random() * this.connections.length)];
  }
}
/**
 * Draws a uniformly distributed float between `min` (inclusive) and `max`
 * (exclusive, up to floating-point rounding).
 * @param {number} min - lower bound
 * @param {number} max - upper bound
 * @returns {number}
 */
function getRandomBetween(min, max) {
  const span = max - min;
  return min + span * Math.random();
}
/**
 * Builds the starting topology: two input nodes wired straight into one
 * output node, each link with a random weight in [-1, 1).
 * @returns {Genome}
 */
function createInitialGenome() {
  const genome = new Genome();
  const inputs = [genome.addNode('input'), genome.addNode('input')];
  const outputNode = genome.addNode('output');
  for (const source of inputs) {
    genome.addConnection(source, outputNode, getRandomBetween(-1, 1));
  }
  return genome;
}
// Smoke test: build a basic genome (2 input nodes, 1 output node, and a
// randomly weighted connection from each input to the output) and run it once.
const genome = createInitialGenome();
const testOutput = genome.forward([1, 0]);
console.log('Test output: ', testOutput);
/**
 * Minimal neuro-evolution loop: a population of genomes evolved by
 * fitness-proportional selection plus mutation (no crossover), scored on XOR.
 */
class NEAT {
  populationSize = 1;
  population = [];
  generation = 0;
  // Per-offspring mutation probabilities; each can be overridden via the constructor.
  mutateWeightChance = 0.8;
  newConnectionChance = 0.05;
  newNodeChance = 0.03;

  /**
   * Creates the initial population of freshly initialised genomes.
   * @param {object} [options]
   * @param {number} [options.populationSize=1] - number of genomes kept per generation
   * @param {number} [options.mutateWeightChance] - overrides the default of 0.8
   * @param {number} [options.newConnectionChance] - overrides the default of 0.05
   * @param {number} [options.newNodeChance] - overrides the default of 0.03
   */
  constructor({
    populationSize = 1,
    mutateWeightChance,
    newConnectionChance,
    newNodeChance,
  } = {}) {
    this.populationSize = populationSize;
    this.mutateWeightChance = mutateWeightChance ?? this.mutateWeightChance;
    this.newConnectionChance = newConnectionChance ?? this.newConnectionChance;
    this.newNodeChance = newNodeChance ?? this.newNodeChance;
    this.population = Array.from({ length: populationSize }, () => createInitialGenome());
  }

  /**
   * Applies up to three mutations to `genome` in place, each independently
   * with its own probability: perturb one weight, add one link, split one link.
   * @param {Genome} genome - genome to mutate (modified in place)
   */
  mutate(genome) {
    if (Math.random() < this.mutateWeightChance) {
      // getRandomConnection() yields undefined for a genome without connections.
      const connection = genome.getRandomConnection();
      if (connection) {
        connection.weight += getRandomBetween(-0.5, 0.5);
      }
    }

    if (Math.random() < this.newConnectionChance) {
      let from = genome.getRandomNode();
      let to = genome.getRandomNode();
      if (from && to && from.type !== to.type) {
        // Orient every new link input -> hidden -> output. A link into an
        // input node would be ignored (inputs keep their assigned value), and
        // a link from a later layer back to an earlier one can form a cycle.
        const layerRank = { input: 0, hidden: 1, output: 2 };
        if (layerRank[from.type] > layerRank[to.type]) {
          [from, to] = [to, from];
        }
        const alreadyLinked = genome.connections.some(
          (connection) => connection.from === from && connection.to === to,
        );
        if (!alreadyLinked) {
          genome.addConnection(from, to, getRandomBetween(-1, 1));
        }
      }
    }

    if (Math.random() < this.newNodeChance) {
      // Split only an enabled link: re-splitting a disabled one would add
      // another path duplicating the one that already replaced it.
      const enabled = genome.connections.filter((connection) => connection.enabled);
      if (enabled.length > 0) {
        const connection = enabled[Math.floor(Math.random() * enabled.length)];
        connection.enabled = false;
        const middleNode = genome.addNode('hidden');
        genome.addConnection(connection.from, middleNode, 1);
        genome.addConnection(middleNode, connection.to, connection.weight);
      }
    }
  }

  /**
   * Scores a genome on the four XOR cases.
   * @param {Genome} genome
   * @returns {number} 1 / (1 + sum of squared errors); 1 means a perfect fit
   */
  computeFitness(genome) {
    // XOR cases
    const cases = [
      { inputs: [0, 0], expected: [0] },
      { inputs: [0, 1], expected: [1] },
      { inputs: [1, 0], expected: [1] },
      { inputs: [1, 1], expected: [0] },
    ];
    const totalError = cases.reduce((acc, { inputs, expected }) => {
      const output = genome.forward(inputs);
      return acc + output.reduce(
        (error, value, index) => error + Math.pow(value - expected[index], 2),
        0,
      );
    }, 0.0);
    // Fitness is inverse of error
    return 1 / (1 + totalError);
  }

  /**
   * Replaces the population with the next generation. Parents are picked by
   * roulette-wheel (fitness-proportional) selection and cloned with mutation.
   * @throws {Error} if offspring are requested but the population is empty
   */
  evolve() {
    console.log('Evolving...');
    const fitnesses = this.population.map((genome) => this.computeFitness(genome));
    const totalFitness = fitnesses.reduce((acc, fitness) => acc + fitness, 0);
    console.log(`Total fitness: ${totalFitness}`);
    if (this.population.length === 0 && this.populationSize > 0) {
      // Previously this spun forever: no parent could ever be selected.
      throw new Error('Cannot evolve an empty population');
    }

    const newPopulation = [];
    while (newPopulation.length < this.populationSize) {
      // Roulette wheel selection
      const pick = Math.random() * totalFitness;
      let current = 0;
      // Default to the last genome: float rounding (pick landing on the total)
      // or a NaN fitness can keep `current > pick` from ever holding, which
      // used to make this loop spin forever.
      let parent = this.population[this.population.length - 1];
      for (let i = 0; i < this.population.length; i++) {
        current += fitnesses[i];
        if (current > pick) {
          parent = this.population[i];
          break;
        }
      }
      newPopulation.push(this.reproduce(parent));
    }
    this.population = newPopulation;
    this.generation++;
  }

  /**
   * Asexual reproduction: deep-copies the parent's nodes and connections into
   * a new genome, then mutates the copy. (No crossover is performed.)
   * Relies on node ids matching their index in `nodes`.
   * @param {Genome} parent
   * @returns {Genome} the mutated offspring
   */
  reproduce(parent) {
    const offspring = new Genome();
    parent.nodes.forEach((node) => {
      offspring.addNode(node.type);
    });
    parent.connections.forEach((connection) => {
      const copy = offspring.addConnection(
        offspring.nodes[connection.from.id],
        offspring.nodes[connection.to.id],
        connection.weight,
      );
      // Keep the disabled state; otherwise links disabled by a node split
      // come back to life in every child, running in parallel with their replacement.
      copy.enabled = connection.enabled;
    });
    this.mutate(offspring);
    return offspring;
  }
}
/**
 * Evolves a fresh population and returns the fittest genome seen.
 * The initial population is scored too, so it can supply the answer even
 * when `maxGenerations` is 0.
 * @param {number} populationSize - genomes per generation
 * @param {number} maxGenerations - upper bound on calls to `evolve()`
 * @param {number} [targetFitness=0.99] - stop early once a genome reaches this fitness
 * @returns {Genome|undefined} the best genome, or undefined if none scored above 0
 */
function findBest(populationSize, maxGenerations, targetFitness = 0.99) {
  console.log('Finding best genome...');
  const neat = new NEAT({ populationSize });
  let bestGenome;
  let bestFitness = 0.0;

  // Scores the current population; only a strictly better genome replaces the
  // best, so the earliest one wins ties.
  const recordBest = () => {
    neat.population.forEach((genome) => {
      const fitness = neat.computeFitness(genome);
      if (fitness > bestFitness) {
        bestFitness = fitness;
        bestGenome = genome;
      }
    });
  };

  recordBest();
  for (let i = 0; i < maxGenerations && bestFitness < targetFitness; i++) {
    console.log(`Generation ${neat.generation}`);
    neat.evolve();
    recordBest();
  }
  return bestGenome;
}
// Evolve an XOR solver and print its prediction for every XOR input pair.
const bestGenome = findBest(100, 100);
console.log('Best genome found: ', bestGenome);
// findBest() leaves its result undefined when no evaluated genome scores above
// 0; fail with a clear message instead of a TypeError on `.forward`.
if (!bestGenome) {
  throw new Error('No genome was found; cannot run XOR predictions');
}
const testCases = [
  { inputs: [0, 0], expected: [0] },
  { inputs: [0, 1], expected: [1] },
  { inputs: [1, 0], expected: [1] },
  { inputs: [1, 1], expected: [0] },
];
const predictions = testCases.map(({ inputs }) => ({
  inputs,
  prediction: bestGenome.forward(inputs),
}));
console.log(predictions);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment