Skip to content

Instantly share code, notes, and snippets.

@esshka
Created February 9, 2024 16:13
Show Gist options
  • Save esshka/67bc0b56df34609222983e43b736dbc6 to your computer and use it in GitHub Desktop.
Save esshka/67bc0b56df34609222983e43b736dbc6 to your computer and use it in GitHub Desktop.
Pole-balancing benchmark for a genetic algorithm
// Upper bound on the number of simulation steps in one cart-pole run; a run
// that reaches this many steps is stopped even if the pole is still upright.
const MAX_TIMESTEPS = 1000;
/**
 * Builds a fresh cart-pole simulation whose state is private to the returned
 * closure.
 *
 * The cart starts at rest at position 0 and the pole starts at rest with a
 * random tilt in [0, 1) degrees (converted to radians).
 *
 * @returns {{getState: Function, step: Function, poleFell: Function}}
 *   `getState()` returns `[cartPosition, cartVelocity, poleAngle,
 *   poleAngularVelocity]`; `step(action)` advances the simulation by one time
 *   step and returns `true` once the pole has fallen; `poleFell()` reports
 *   whether the pole angle is currently beyond +/- 90 degrees.
 */
function initializeCartPoleEnvironment() {
  // Physical parameters of the system.
  const gravity = 9.8; // m/s^2
  const cartMass = 1.0;
  const poleMass = 0.1;
  const totalMass = cartMass + poleMass;
  const length = 0.5; // half of the pole's length
  const poleMassLength = poleMass * length;
  const forceMagnitude = 10.0; // size of the push applied to the cart
  const tau = 0.02; // seconds of simulated time per step
  const fallAngle = Math.PI / 2; // the pole counts as fallen beyond this tilt

  // Mutable simulation state, kept in closure variables.
  let cartPosition = 0;
  let cartVelocity = 0;
  let poleAngle = (Math.PI / 180) * Math.random();
  let poleAngularVelocity = 0;

  // Snapshot of the current state as a new four-element array.
  const getState = () => [
    cartPosition,
    cartVelocity,
    poleAngle,
    poleAngularVelocity,
  ];

  // True when the pole is tilted more than 90 degrees to either side.
  const poleFell = () => Math.abs(poleAngle) > fallAngle;

  // Advance the simulation by one explicit-Euler step of length tau.
  const step = (action) => {
    // Only an action of exactly 1 pushes in the positive direction; every
    // other value pushes in the negative direction.
    let force = -forceMagnitude;
    if (action === 1) {
      force = forceMagnitude;
    }

    const sinTheta = Math.sin(poleAngle);
    const cosTheta = Math.cos(poleAngle);

    const temp =
      (force + poleMassLength * poleAngularVelocity ** 2 * sinTheta) /
      totalMass;
    const massRatioTerm = (poleMass * cosTheta ** 2) / totalMass;
    const angularAcceleration =
      (gravity * sinTheta - cosTheta * temp) /
      (length * (4.0 / 3.0 - massRatioTerm));
    const linearAcceleration =
      temp - (poleMassLength * angularAcceleration * cosTheta) / totalMass;

    // Positions are advanced with the velocities from before this step, and
    // only then are the velocities themselves updated.
    cartPosition += tau * cartVelocity;
    cartVelocity += tau * linearAcceleration;
    poleAngle += tau * poleAngularVelocity;
    poleAngularVelocity += tau * angularAcceleration;

    return poleFell();
  };

  return { getState, step, poleFell };
}
/**
 * Runs the pole-balancing benchmark: evolves a population of networks on the
 * cart-pole task and logs a short summary when evolution ends.
 *
 * Every generation each genome is simulated once. Its score is the number of
 * timesteps it ran for, minus a small penalty proportional to its number of
 * hidden nodes plus connections. The top `elitism` genomes are carried over
 * unchanged and the rest of the next population is filled with mutated
 * offspring.
 *
 * Evolution stops after `maxGenerations`, or as soon as some genome keeps the
 * pole up for `targetFitness` timesteps. The stop test deliberately uses the
 * raw timestep count: the penalised score of any genome with at least one
 * connection or hidden node is strictly below the timestep cap, so comparing
 * the penalised score against the target could never end the run early.
 *
 * @returns {{bestGenome: (Object|null)}} A copy of the genome with the highest
 *   penalised score seen, or `null` if no genome ever scored above 0.
 */
export function runPoleBench() {
  const populationSize = 100;
  const elitism = 10;
  const mutationRate = 0.3;
  const growth = 0.0001;
  const maxGenerations = 1000;
  const targetFitness = 1000;

  /**
   * Simulates one cart-pole episode driven by `network`.
   * Returns the number of steps executed, including the step on which the
   * pole fell; capped at MAX_TIMESTEPS.
   */
  function simulateCartPole(network) {
    const environment = initializeCartPoleEnvironment();
    let steps = 0;
    let fell = false;
    do {
      const input = environment.getState();
      // The rounded network output is used as the discrete action
      // (presumably 0 = push left, 1 = push right; verify against
      // activateNetwork's output range).
      const action = Math.round(activateNetwork(network, input));
      fell = environment.step(action);
      steps += 1;
    } while (!fell && steps < MAX_TIMESTEPS);
    return steps;
  }

  /** Score penalty that grows with hidden-node and connection counts. */
  function complexityPenalty(genome) {
    const hiddenNodes =
      genome.nodes.length - genome.inputSize - genome.outputSize;
    return (hiddenNodes + genome.connections.length) * growth;
  }

  const networkTemplate = buildNetwork(4, 1);
  let population = createPopulation({
    populationSize,
    networkTemplate,
  });

  let generation = 0;
  let bestFitness = 0;
  let bestGenome = null;
  let solved = false;

  while (generation < maxGenerations && !solved) {
    for (const genome of population) {
      const steps = simulateCartPole(genome);
      const fitness = steps - complexityPenalty(genome);
      genome.score = fitness;
      if (steps >= targetFitness) {
        solved = true;
      }
      if (fitness > bestFitness) {
        bestGenome = copyNetwork(genome);
        bestFitness = fitness;
      }
    }
    generation++;

    // Do not breed a generation that will never be evaluated.
    if (solved || generation >= maxGenerations) {
      break;
    }

    const sortedByFitness = sortPopulation(population);
    const newPopulation = sortedByFitness.slice(0, elitism);
    while (newPopulation.length < populationSize) {
      newPopulation.push(getOffspring(sortedByFitness));
    }
    // Only the non-elite tail is handed to the mutation step.
    mutatePopulation(newPopulation.slice(elitism), { mutationRate });
    population = newPopulation;
  }

  console.log("Evolution complete");
  console.log(`Best fitness achieved: ${bestFitness}`);
  console.log(`Generations: ${generation}`);
  // bestGenome stays null when no genome scored above 0 (for example an
  // empty population); skip the size report instead of throwing.
  if (bestGenome !== null) {
    console.log(`Best Genome nodes: ${bestGenome.nodes.length}`);
    console.log(`Best Genome conns: ${bestGenome.connections.length}`);
  }

  return {
    bestGenome,
  };
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment