Created
February 9, 2024 16:13
-
-
Save esshka/67bc0b56df34609222983e43b736dbc6 to your computer and use it in GitHub Desktop.
Pole-balancing benchmark for a genetic algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Maximum number of simulation steps a single cart-pole episode may run. */
const MAX_TIMESTEPS = 1000;
/**
 * Creates a fresh cart-pole simulation with its own private state.
 *
 * The cart starts at rest at position 0 and the pole starts at rest, tilted
 * by a random angle in [0, 1) degrees (stored in radians).
 *
 * @returns {{
 *   getState: () => number[],
 *   step: (action: number) => boolean,
 *   poleFell: () => boolean
 * }} Environment interface:
 *   - `getState()` returns
 *     `[cartPosition, cartVelocity, poleAngle, poleAngularVelocity]`.
 *   - `step(action)` advances the simulation by one time step and returns
 *     `true` once the pole has fallen.
 *   - `poleFell()` reports whether the pole angle is beyond +/- PI/2.
 */
function initializeCartPoleEnvironment() {
  const gravity = 9.8; // Acceleration due to gravity, m/s^2
  const cartMass = 1.0; // Mass of the cart
  const poleMass = 0.1; // Mass of the pole
  const totalMass = cartMass + poleMass;
  const length = 0.5; // Half the pole's length
  const poleMassLength = poleMass * length;
  const forceMagnitude = 10.0; // Magnitude of the force applied to the cart
  const tau = 0.02; // Time step for simulation, seconds
  const fallAngle = Math.PI / 2; // Pole counts as fallen beyond +/- this angle

  const state = {
    cartPosition: 0,
    cartVelocity: 0,
    poleAngle: (Math.PI / 180) * Math.random(), // Slightly tilted pole at start
    poleAngularVelocity: 0,
  };

  /** Returns a snapshot of the current state as a new array. */
  function getState() {
    return [
      state.cartPosition,
      state.cartVelocity,
      state.poleAngle,
      state.poleAngularVelocity,
    ];
  }

  /** Returns true when the pole angle lies outside [-fallAngle, fallAngle]. */
  function poleFell() {
    return state.poleAngle < -fallAngle || state.poleAngle > fallAngle;
  }

  /**
   * Advances the simulation by one time step of length `tau`.
   *
   * @param {number} action - `1` pushes the cart in the positive direction;
   *   any other value pushes it in the negative direction.
   * @returns {boolean} `true` if the pole has fallen after this step.
   */
  function step(action) {
    // Calculate the force applied to the cart
    const force = action === 1 ? forceMagnitude : -forceMagnitude;

    // Equations of motion for the cart and pole (simplified)
    const cosTheta = Math.cos(state.poleAngle);
    const sinTheta = Math.sin(state.poleAngle);
    const temp =
      (force + poleMassLength * state.poleAngularVelocity ** 2 * sinTheta) /
      totalMass;
    const angularAcceleration =
      (gravity * sinTheta - cosTheta * temp) /
      (length * (4.0 / 3.0 - (poleMass * cosTheta ** 2) / totalMass));
    const linearAcceleration =
      temp - (poleMassLength * angularAcceleration * cosTheta) / totalMass;

    // Forward Euler update: positions advance with the velocities from
    // before this step, so the order of these four statements matters.
    state.cartPosition += tau * state.cartVelocity;
    state.cartVelocity += tau * linearAcceleration;
    state.poleAngle += tau * state.poleAngularVelocity;
    state.poleAngularVelocity += tau * angularAcceleration;

    // Single source of truth for the termination condition
    return poleFell();
  }

  // Return the interface to the environment
  return { getState, step, poleFell };
}
/**
 * Runs the pole-balancing benchmark: evolves a population of networks until
 * one of them keeps the pole up for `MAX_TIMESTEPS` steps or the generation
 * limit is reached, then logs a summary.
 *
 * Relies on `buildNetwork`, `createPopulation`, `activateNetwork`,
 * `copyNetwork`, `sortPopulation`, `getOffspring` and `mutatePopulation`
 * being in scope.
 * NOTE(review): none of these are defined or imported in the visible file;
 * confirm where they come from.
 *
 * @returns {{ bestGenome: object | null }} A copy of the genome with the
 *   highest penalized fitness seen, or `null` if no genome scored above 0.
 */
export function runPoleBench() {
  const populationSize = 100;
  const elitism = 10;
  const mutationRate = 0.3;
  const growth = 0.0001;
  const maxGenerations = 1000;
  // Number of balanced steps that counts as "solved". An episode is capped at
  // MAX_TIMESTEPS, so that is the highest step count a genome can reach.
  const targetFitness = MAX_TIMESTEPS;

  /**
   * Runs one cart-pole episode driven by `network`.
   *
   * @param {object} network - Genome passed to `activateNetwork`.
   * @returns {number} Steps taken before the pole fell (the failing step is
   *   counted), capped at `MAX_TIMESTEPS`.
   */
  function simulateCartPole(network) {
    const environment = initializeCartPoleEnvironment();
    let steps = 0;
    let fell = false;
    do {
      const input = environment.getState();
      // Rounded to a discrete action; the environment treats 1 as "push
      // right" and anything else as "push left".
      const action = Math.round(activateNetwork(network, input));
      fell = environment.step(action);
      steps += 1;
    } while (!fell && steps < MAX_TIMESTEPS);
    return steps;
  }

  /**
   * Size measure used for the growth penalty: nodes beyond the fixed
   * input/output nodes, plus connections.
   */
  function complexityOf(genome) {
    return (
      genome.nodes.length -
      genome.inputSize -
      genome.outputSize +
      genome.connections.length
    );
  }

  const networkTemplate = buildNetwork(4, 1);
  let population = createPopulation({
    populationSize,
    networkTemplate,
  });

  let generation = 0;
  let bestFitness = 0;
  let bestGenome = null;
  let solved = false;

  // Termination checks the raw step count, not the penalized fitness: the
  // penalty keeps penalized fitness below MAX_TIMESTEPS for any genome with
  // at least one connection, so comparing it to the target never ended the
  // run early.
  while (generation < maxGenerations && !solved) {
    for (const genome of population) {
      const steps = simulateCartPole(genome);
      const fitness = steps - complexityOf(genome) * growth;
      genome.score = fitness;
      if (steps >= targetFitness) {
        solved = true;
      }
      if (fitness > bestFitness) {
        bestGenome = copyNetwork(genome);
        bestFitness = fitness;
      }
    }

    // NOTE(review): assumes sortPopulation returns genomes best-first, since
    // the leading `elitism` entries are carried over unchanged — confirm.
    const sortedByFitness = sortPopulation(population);
    const elites = sortedByFitness.slice(0, elitism);
    const newPopulation = [...elites];
    while (newPopulation.length < populationSize) {
      const offspring = getOffspring(sortedByFitness);
      newPopulation.push(offspring);
    }

    // Only the non-elite tail is passed in. NOTE(review): the slice is a new
    // array, so this only has an effect if mutatePopulation mutates the
    // genome objects in place — confirm.
    mutatePopulation(newPopulation.slice(elitism), {
      mutationRate: mutationRate,
    });

    population = newPopulation;
    generation++;
  }

  console.log("Evolution complete");
  console.log(`Best fitness achieved: ${bestFitness}`);
  console.log(`Generations: ${generation}`);
  // bestGenome stays null when no genome ever scored above 0.
  if (bestGenome !== null) {
    console.log(`Best Genome nodes: ${bestGenome.nodes.length}`);
    console.log(`Best Genome conns: ${bestGenome.connections.length}`);
  }

  return {
    bestGenome,
  };
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment