Created November 9, 2024 22:31
-
-
Save Sinjhin/8a5811a52d99758aedde42e86df1f1fe to your computer and use it in GitHub Desktop.
An attempt to implement NEAT (NeuroEvolution of Augmenting Topologies) from scratch while reading the original paper.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import math | |
import random | |
import copy | |
# XOR truth table as ([a, b], a XOR b) pairs, enumerated in (0,0), (0,1),
# (1,0), (1,1) order.
XOR_INPUTS = [([a, b], a ^ b) for a in (0, 1) for b in (0, 1)]

# Speciation settings.
C1 = 1.0  # coefficient on the excess-gene count in the compatibility distance
C2 = 1.0  # coefficient on the disjoint-gene count
C3 = 0.4  # coefficient on the average weight difference
COMPATIBILITY_THRESHOLD = 6.0  # genomes whose distance is below this share a species
STAGNATION_LIMIT = 15  # generations without improvement before a species is dropped
def sigmoid(x):
    """Return the logistic function 1 / (1 + e**-x).

    Computed in two branches so that math.exp never receives a large
    positive argument: for very negative x the naive form raises
    OverflowError, while this form returns 0.0.
    """
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For x < 0, e**x / (1 + e**x) is the same value without overflow risk.
    z = math.exp(x)
    return z / (1.0 + z)
class NodeGene:
    """A single neuron in a genome.

    Args:
        node_id: identifier of the node within its genome.
        node_type: one of 'input', 'bias', 'output' or 'hidden'. Bias nodes
            start at value 1.0; every other node starts at 0.0.
    """

    def __init__(self, node_id, node_type):
        self.node_id = node_id
        self.node_type = node_type  # 'input', 'bias', 'output', 'hidden'
        self.value = 1.0 if node_type == 'bias' else 0.0
        # (source NodeGene, weight) pairs appended by Genome.add_connection.
        # NOTE(review): the weight is copied when the link is created, so a
        # later change to ConnectionGene.weight (or disabling the gene) is not
        # reflected here unless the caller rebuilds this list.
        self.inputs = []

    def activate(self):
        """Set ``value`` to sigmoid of the weighted sum of ``inputs``.

        Input and bias nodes are sources and keep their current value. A bias
        node has no inputs, so activating it would otherwise replace its
        constant 1.0 with sigmoid(0) == 0.5.
        """
        if self.node_type in ('input', 'bias'):
            return
        self.value = sigmoid(sum(inp.value * weight for inp, weight in self.inputs))
class ConnectionGene:
    """A connection gene: a weighted link from ``in_node`` to ``out_node``.

    ``innovation`` is the number Population.is_compatible uses to line genes
    up between two genomes. ``enabled`` is set to False by
    Genome.add_node_mutation when the link is split by a new node.
    """

    def __init__(self, in_node, out_node, weight, enabled, innovation):
        # Endpoints of the link.
        self.in_node, self.out_node = in_node, out_node
        # Mutable gene state.
        self.weight = weight
        self.enabled = enabled
        # Historical marker.
        self.innovation = innovation
class Genome:
    """A NEAT genome: node genes keyed by id plus an ordered list of connection genes.

    Attributes:
        nodes: dict mapping node_id -> NodeGene.
        connections: ConnectionGene objects in the order they were added.
        fitness: last fitness assigned by the caller (starts at 0.0).
        global_innovation: next innovation number this genome hands out.

    NOTE(review): despite its name, ``global_innovation`` is a per-genome
    counter, and it is deep-copied into children. The same innovation number
    in two genomes therefore need not describe the same structural mutation,
    although the gene alignment in Population.is_compatible assumes it does.
    A population-wide innovation registry would be needed to fix that.
    """

    def __init__(self):
        self.nodes = {}
        self.connections = []
        self.fitness = 0.0
        self.global_innovation = 0

    def add_connection(self, in_node, out_node, weight):
        """Append an enabled connection gene ``in_node -> out_node``.

        ``out_node.inputs`` is updated too. forward() rebuilds those lists from
        ``self.connections`` before evaluating, so later weight mutations and
        disabled genes are honoured.
        """
        conn = ConnectionGene(in_node, out_node, weight, True, self.global_innovation)
        self.global_innovation += 1
        self.connections.append(conn)
        out_node.inputs.append((in_node, weight))

    def mutate(self):
        """Apply weight and structural mutations in place.

        Each connection weight gets N(0, 1) noise with probability 0.8. Then a
        node is added with probability 0.03 and a connection with
        probability 0.05.
        """
        for conn in self.connections:
            if random.random() < 0.8:  # 80% chance to mutate each weight
                conn.weight += random.gauss(0, 1)
        # Structural mutations (adding new nodes/connections)
        if random.random() < 0.03:  # 3% chance to add a new node
            self.add_node_mutation()
        if random.random() < 0.05:  # 5% chance to add a new connection
            self.add_connection_mutation()

    def add_node_mutation(self):
        """Split a random enabled connection A->B into A->new->B.

        The old gene is disabled. A->new gets weight 1.0 and new->B inherits
        the old weight, so the split initially changes the signal as little
        as possible.
        """
        enabled_connections = [c for c in self.connections if c.enabled]
        if not enabled_connections:
            print("No enabled connections available for node mutation.")
            return
        conn = random.choice(enabled_connections)
        conn.enabled = False  # Disable old connection
        # Ids are 0..len-1 and nodes are never removed, so len() is a fresh id.
        new_node = NodeGene(len(self.nodes), 'hidden')
        self.nodes[new_node.node_id] = new_node
        self.add_connection(conn.in_node, new_node, 1.0)
        self.add_connection(new_node, conn.out_node, conn.weight)

    def add_connection_mutation(self):
        """Add a connection with a U(-1, 1) weight between two unlinked nodes.

        The target may not be the source itself, an input, or the bias node.
        Pairs that already have a connection gene (enabled or not) are skipped,
        so repeated mutations do not stack duplicate links.
        """
        if not self.nodes:
            print("No eligible nodes available for connection mutation.")
            return
        node1 = random.choice(list(self.nodes.values()))
        existing = {(c.in_node.node_id, c.out_node.node_id) for c in self.connections}
        eligible_nodes = [
            n for n in self.nodes.values()
            if n is not node1
            and n.node_type not in ('input', 'bias')
            and (node1.node_id, n.node_id) not in existing
        ]
        if eligible_nodes:
            node2 = random.choice(eligible_nodes)
            self.add_connection(node1, node2, random.uniform(-1, 1))
        else:
            print("No eligible nodes available for connection mutation.")

    def _sync_node_inputs(self):
        """Rebuild every node's ``inputs`` list from the enabled connection genes."""
        for node in self.nodes.values():
            node.inputs = []
        for conn in self.connections:
            if conn.enabled:
                conn.out_node.inputs.append((conn.in_node, conn.weight))

    def _evaluation_order(self):
        """Return the non-source nodes ordered so each follows the nodes feeding it.

        Uses Kahn's algorithm over enabled connections. Nodes caught in a
        cycle (add_connection_mutation can create recurrent links) are appended
        afterwards in id order. They read whatever value their recurrent
        source currently holds.
        """
        pending = {
            node_id: 0 for node_id, node in self.nodes.items()
            if node.node_type not in ('input', 'bias')
        }
        successors = {node_id: [] for node_id in pending}
        for conn in self.connections:
            src, dst = conn.in_node.node_id, conn.out_node.node_id
            if conn.enabled and src in pending and dst in pending:
                pending[dst] += 1
                successors[src].append(dst)
        order = [node_id for node_id in sorted(pending) if pending[node_id] == 0]
        k = 0
        while k < len(order):
            for nxt in successors[order[k]]:
                pending[nxt] -= 1
                if pending[nxt] == 0:
                    order.append(nxt)
            k += 1
        placed = set(order)
        order.extend(node_id for node_id in sorted(pending) if node_id not in placed)
        return [self.nodes[node_id] for node_id in order]

    def forward(self, inputs):
        """Propagate ``inputs`` through the network and return the output value.

        Args:
            inputs: indexable sequence with one value per 'input' node, matched
                to input nodes in ``self.nodes`` insertion order.

        Returns:
            The value of the first 'output' node. If the genome has no output
            node, the value of the node with the largest id.

        Raises:
            ValueError: if len(inputs) differs from the number of input nodes.
        """
        input_nodes = [n for n in self.nodes.values() if n.node_type == 'input']
        if len(inputs) != len(input_nodes):
            raise ValueError(
                f"expected {len(input_nodes)} inputs, got {len(inputs)}"
            )
        for node, value in zip(input_nodes, inputs):
            node.value = value
        for node in self.nodes.values():
            if node.node_type == 'bias':
                node.value = 1.0  # bias is a constant source
        # Use current weights and honour disabled genes.
        self._sync_node_inputs()
        # Hidden nodes may have larger ids than the output node, so id order
        # is not a valid evaluation order; use dependency order instead.
        for node in self._evaluation_order():
            node.activate()
        outputs = [n for n in self.nodes.values() if n.node_type == 'output']
        result = outputs[0] if outputs else self.nodes[max(self.nodes)]
        return result.value
class Species:
    """A group of mutually compatible genomes that share fitness.

    Args:
        representative: genome that candidate members are compared against.
    """

    def __init__(self, representative):
        self.representative = representative
        self.genomes = []
        self.best_fitness = 0.0  # best member fitness seen so far
        self.generations_without_improvement = 0

    def add_genome(self, genome):
        """Add ``genome`` as a member of this species."""
        self.genomes.append(genome)

    def calculate_shared_fitness(self):
        """Divide each member's fitness by the species size (fitness sharing).

        This modifies ``genome.fitness`` in place, so call it at most once per
        evaluation.
        """
        size = len(self.genomes)
        for genome in self.genomes:
            genome.fitness /= size

    def update_best_fitness(self):
        """Record the best member fitness and update the stagnation counter.

        An empty species counts as a generation without improvement rather
        than raising from max() on an empty sequence.
        """
        if not self.genomes:
            self.generations_without_improvement += 1
            return
        current_best = max(genome.fitness for genome in self.genomes)
        if current_best > self.best_fitness:
            self.best_fitness = current_best
            self.generations_without_improvement = 0
        else:
            self.generations_without_improvement += 1

    def is_extinct(self, limit=None):
        """Return True once the species has stagnated for ``limit`` generations.

        Args:
            limit: stagnation threshold; defaults to STAGNATION_LIMIT.
        """
        if limit is None:
            limit = STAGNATION_LIMIT
        return self.generations_without_improvement >= limit
class Population:
    """A NEAT population evolved on the XOR task (XOR_INPUTS).

    Species persist across generations, so their stagnation counters can
    actually reach STAGNATION_LIMIT.
    """

    def __init__(self, pop_size):
        self.pop_size = pop_size
        self.species = []  # kept across generations; maintained by speciate()
        self.genomes = [self.create_initial_genome() for _ in range(pop_size)]

    def create_initial_genome(self):
        """Return a minimal genome.

        The genome has inputs 0 and 1, bias 2 and output 3. Each source is
        wired directly to the output with a U(-1, 1) weight.
        """
        genome = Genome()
        genome.nodes[0] = NodeGene(0, 'input')
        genome.nodes[1] = NodeGene(1, 'input')
        genome.nodes[2] = NodeGene(2, 'bias')
        genome.nodes[3] = NodeGene(3, 'output')
        genome.add_connection(genome.nodes[0], genome.nodes[3], random.uniform(-1, 1))
        genome.add_connection(genome.nodes[1], genome.nodes[3], random.uniform(-1, 1))
        genome.add_connection(genome.nodes[2], genome.nodes[3], random.uniform(-1, 1))
        return genome

    def speciate(self):
        """Assign every genome to a species, keeping existing species alive.

        Each existing species keeps its identity and stagnation counter, but
        its member list is cleared. Every genome then joins the first species
        whose representative it is compatible with, or founds a new species.
        Empty species are dropped. Each survivor's representative becomes a
        random current member, which the next generation is compared against.
        """
        species_list = list(getattr(self, 'species', None) or [])
        for species in species_list:
            species.genomes = []
        for genome in self.genomes:
            for species in species_list:
                if self.is_compatible(genome, species.representative):
                    species.add_genome(genome)
                    break
            else:
                new_species = Species(genome)
                new_species.add_genome(genome)
                species_list.append(new_species)
        self.species = [species for species in species_list if species.genomes]
        for species in self.species:
            species.representative = random.choice(species.genomes)

    def is_compatible(self, genome1, genome2):
        """Return True if the compatibility distance is below COMPATIBILITY_THRESHOLD.

        distance = C1*E + C2*D + C3*W, where:
          - genes are aligned by innovation number;
          - E (excess) counts unmatched genes beyond the smaller of the two
            genomes' highest innovation numbers;
          - D (disjoint) counts the remaining unmatched genes;
          - W is the mean absolute weight difference of matching genes.
        E and D are not normalised by genome size.
        """
        genes1 = {conn.innovation: conn for conn in genome1.connections}
        genes2 = {conn.innovation: conn for conn in genome2.connections}
        # An empty genome has max innovation -1, so every gene of the other is excess.
        boundary = min(max(genes1, default=-1), max(genes2, default=-1))
        excess_genes = disjoint_genes = matching_genes = 0
        weight_diff = 0.0
        for innovation in genes1.keys() | genes2.keys():
            conn1 = genes1.get(innovation)
            conn2 = genes2.get(innovation)
            if conn1 is not None and conn2 is not None:
                matching_genes += 1
                weight_diff += abs(conn1.weight - conn2.weight)
            elif innovation > boundary:
                excess_genes += 1
            else:
                disjoint_genes += 1
        avg_weight_diff = weight_diff / matching_genes if matching_genes else 0.0
        distance = C1 * excess_genes + C2 * disjoint_genes + C3 * avg_weight_diff
        return distance < COMPATIBILITY_THRESHOLD

    def evaluate_fitness(self, genome):
        """Set and return ``genome.fitness`` = 4 - sum of squared XOR errors (max 4)."""
        error = 0.0
        for inputs, target in XOR_INPUTS:
            output = genome.forward(inputs)
            error += (output - target) ** 2
        genome.fitness = 4 - error
        return genome.fitness

    def evolve(self, generations, fitness_threshold=3.9):
        """Run up to ``generations`` generations of evaluation and reproduction.

        Args:
            generations: maximum number of generations to run.
            fitness_threshold: stop as soon as a genome's raw fitness reaches
                this value. The default 3.9 leaves total squared error at most
                0.1, so every XOR case rounds to the right answer. Pass None
                to never stop early.

        Returns:
            The first genome reaching ``fitness_threshold``, or None.
        """
        for i in range(generations):
            if not self.genomes:
                return None
            for genome in self.genomes:
                self.evaluate_fitness(genome)
            best = max(self.genomes, key=lambda g: g.fitness)
            if fitness_threshold is not None and best.fitness >= fitness_threshold:
                print(f"Solution found in generation {i} with fitness {best.fitness}")
                return best
            self.speciate()
            # Track stagnation on raw fitness, before sharing divides it by species size.
            for species in self.species:
                species.update_best_fitness()
            survivors = [species for species in self.species if not species.is_extinct()]
            if not survivors:
                # Never let the whole population die out.
                survivors = [max(self.species, key=lambda s: s.best_fitness)]
            self.species = survivors
            print(f"Species count after extinction: {len(self.species)}")
            self.adjust_fitness_within_species()
            self.genomes = self._next_generation(i)
        return None

    def _next_generation(self, generation):
        """Build the next generation: one elite per species plus mutated offspring."""
        next_gen = []
        for j, species in enumerate(self.species):
            species_best = max(species.genomes, key=lambda g: g.fitness)
            print(f"Gen: {generation} - Species: {j} Champion with {len(species_best.connections)} cons, {len(species_best.nodes)} nodes, and fitness {species_best.fitness}")
            next_gen.append(copy.deepcopy(species_best))  # Preserve best from each species
        quotas = self._offspring_quotas(self.pop_size - len(next_gen))
        for species, quota in zip(self.species, quotas):
            for _ in range(quota):
                next_gen.append(self.mutate(random.choice(species.genomes)))
        return next_gen

    def _offspring_quotas(self, total):
        """Split ``total`` offspring slots across species by summed shared fitness."""
        count = len(self.species)
        if total <= 0 or count == 0:
            return [0] * count
        sums = [sum(g.fitness for g in species.genomes) for species in self.species]
        grand_total = sum(sums)
        if grand_total > 0:
            shares = [total * s / grand_total for s in sums]
        else:
            shares = [total / count] * count
        quotas = [int(share) for share in shares]
        # Give leftover slots to the largest fractional remainders so quotas sum to total.
        by_remainder = sorted(range(count), key=lambda k: shares[k] - quotas[k], reverse=True)
        for n in range(total - sum(quotas)):
            quotas[by_remainder[n % count]] += 1
        return quotas

    def adjust_fitness_within_species(self):
        """Apply fitness sharing inside every current species."""
        for species in self.species:
            species.calculate_shared_fitness()

    def reproduce(self):
        """Truncation reproduction (not used by evolve).

        Keep copies of the fittest half of the population, then refill to
        exactly ``pop_size`` with mutated children of those survivors.
        """
        if not self.genomes:
            return
        ranked = sorted(self.genomes, key=lambda g: g.fitness, reverse=True)
        culled_gen = ranked[: max(1, self.pop_size // 2)]
        next_gen = [copy.deepcopy(genome) for genome in culled_gen]
        k = 0
        while len(next_gen) < self.pop_size:
            next_gen.append(self.mutate(culled_gen[k % len(culled_gen)]))
            k += 1
        self.genomes = next_gen

    def mutate(self, genome):
        """Return a mutated deep copy of ``genome``; the original is untouched."""
        child = copy.deepcopy(genome)
        child.mutate()
        return child
class NEATCore:
    """Entry point that evolves a NEAT population on XOR.

    Args:
        loader_class: optional callable; when given it is called with the
            params and the result is stored as ``self.loader``.
        trainer_class: stored as-is; run() does not use it.
        params: optional dict-like settings. ``cull_rate`` is read (default
            0.9). None is treated as an empty dict.
    """

    def __init__(self, loader_class=None, trainer_class=None, params=None):
        self.name = 'NEATCore'
        # The documented defaults (None) previously crashed on None(params) / None.get.
        self.params = params if params is not None else {}
        self.loader = loader_class(self.params) if loader_class is not None else None
        self.trainer_class = trainer_class
        # NOTE(review): cull_rate is stored but not read by the visible code.
        self.cull_rate = self.params.get('cull_rate', 0.9)
        self.nodes = []

    def run(self, pop_size=20, generations=100000):
        """Evolve an XOR-solving network and print the solution's connections."""
        population = Population(pop_size=pop_size)
        print("Population:")
        for genome in population.genomes:
            print(f"Genome with {len(genome.connections)} cons, {len(genome.nodes)} nodes, and fitness {genome.fitness}")
        solution = population.evolve(generations=generations)
        if solution:
            print("Solution network weights:")
            for conn in solution.connections:
                print(f"From {conn.in_node.node_id} to {conn.out_node.node_id}, weight: {conn.weight}")
        else:
            print("No solution found within the specified generations.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment