Skip to content

Instantly share code, notes, and snippets.

@Sinjhin
Created November 9, 2024 22:31
Show Gist options
  • Save Sinjhin/8a5811a52d99758aedde42e86df1f1fe to your computer and use it in GitHub Desktop.
Save Sinjhin/8a5811a52d99758aedde42e86df1f1fe to your computer and use it in GitHub Desktop.
# An attempt at implementing NEAT from scratch while reading the paper.
import torch.nn as nn
import math
import random
import copy
# XOR truth table as (input bits, expected output) pairs.
XOR_INPUTS = [
    ([0, 0], 0),
    ([0, 1], 1),
    ([1, 0], 1),
    ([1, 1], 0),
]

# Speciation settings.
# Compatibility distance coefficients: C1 scales the excess-gene count,
# C2 the disjoint-gene count and C3 the average weight difference.
C1 = 1.0
C2 = 1.0
C3 = 0.4
# Two genomes belong to the same species when their distance is below this.
COMPATIBILITY_THRESHOLD = 6.0
# Generations without improvement after which a species counts as extinct.
STAGNATION_LIMIT = 15
def sigmoid(x):
    """Return the logistic function 1 / (1 + e^-x) of *x*.

    The computation is split by sign so the exponential is only ever taken
    of a non-positive number. The naive form calls ``math.exp(-x)``, which
    raises ``OverflowError`` once ``-x`` exceeds roughly 709.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    # For negative x use the algebraically equal form e^x / (1 + e^x);
    # e^x underflows harmlessly to 0.0 instead of overflowing.
    exp_x = math.exp(x)
    return exp_x / (1 + exp_x)
class NodeGene:
    """A single node (neuron) of a genome's network."""

    def __init__(self, node_id, node_type):
        """Create a node.

        Args:
            node_id: Identifier of the node within its genome.
            node_type: One of 'input', 'bias', 'output' or 'hidden'.
        """
        self.node_id = node_id
        self.node_type = node_type  # 'input', 'bias', 'output', 'hidden'
        # A bias node holds a constant 1.0; every other node starts at 0.0.
        self.value = 1.0 if node_type == 'bias' else 0.0
        # (source node, weight) pairs feeding this node.
        self.inputs = []

    def activate(self):
        """Recompute ``value`` as the sigmoid of the weighted input sum.

        Input and bias nodes are left untouched: an input's value is set by
        the caller, and a bias node must keep its constant 1.0 (activating
        it would overwrite that with sigmoid(0) == 0.5).
        """
        if self.node_type in ('input', 'bias'):
            return
        self.value = sigmoid(sum(inp.value * weight for inp, weight in self.inputs))
class ConnectionGene:
    """A weighted link from one node to another, tagged with an innovation number."""

    def __init__(self, in_node, out_node, weight, enabled, innovation):
        """Store the endpoints, weight, enabled flag and innovation number."""
        # Endpoints of the link (source first, then target).
        self.in_node, self.out_node = in_node, out_node
        # Link strength and whether the link is currently switched on.
        self.weight, self.enabled = weight, enabled
        # Number assigned by the creating genome when the link was added.
        self.innovation = innovation
class Genome:
    """A NEAT genome: node genes, connection genes and the network they encode.

    NOTE(review): innovation numbers come from a per-genome counter
    (``global_innovation``), which is copied along with the genome. Two
    genomes that add different connections can therefore hand out the same
    innovation number. Coordinating the numbers across genomes needs a
    shared registry and is left unchanged here.
    """

    def __init__(self):
        self.nodes = {}  # node_id -> node gene
        self.connections = []  # connection genes in creation order
        self.fitness = 0.0
        self.global_innovation = 0  # next innovation number this genome assigns

    def add_connection(self, in_node, out_node, weight):
        """Append an enabled connection from *in_node* to *out_node*."""
        conn = ConnectionGene(in_node, out_node, weight, True, self.global_innovation)
        self.global_innovation += 1
        self.connections.append(conn)
        # Kept for code that inspects node.inputs directly; forward() rebuilds
        # these lists from the connection genes before every evaluation.
        out_node.inputs.append((in_node, weight))

    def mutate(self):
        """Perturb weights and occasionally add a node and/or a connection."""
        # Mutate weights
        for conn in self.connections:
            if random.random() < 0.8:  # 80% chance to mutate weights
                conn.weight += random.gauss(0, 1)
        # Structural mutations (adding new nodes/connections)
        if random.random() < 0.03:  # 3% chance to add a new node
            self.add_node_mutation()
        if random.random() < 0.05:  # 5% chance to add a new connection
            self.add_connection_mutation()

    def add_node_mutation(self):
        """Split a random enabled connection by inserting a hidden node."""
        enabled_connections = [c for c in self.connections if c.enabled]
        if not enabled_connections:
            print("No enabled connections available for node mutation.")
            return
        conn = random.choice(enabled_connections)
        conn.enabled = False  # Disable old connection
        # Next free id; equals len(self.nodes) while ids are contiguous.
        new_node = NodeGene(max(self.nodes, default=-1) + 1, 'hidden')
        self.nodes[new_node.node_id] = new_node
        # Connect in_node -> new_node -> out_node. The first link gets weight
        # 1.0 and the second inherits the weight of the split connection.
        self.add_connection(conn.in_node, new_node, 1.0)
        self.add_connection(new_node, conn.out_node, conn.weight)

    def add_connection_mutation(self):
        """Add a connection from a random node to a random eligible target.

        A target is eligible when it is not the source itself, is not an
        input or bias node (those are never activated), is not already
        connected from the source, and would not close a cycle.
        """
        if not self.nodes:
            print("No eligible nodes available for connection mutation.")
            return
        node1 = random.choice(list(self.nodes.values()))
        existing = {(c.in_node.node_id, c.out_node.node_id) for c in self.connections}
        # Linking node1 -> n closes a cycle exactly when n already reaches node1.
        ancestors = self._ancestor_ids(node1)
        eligible_nodes = [
            n for n in self.nodes.values()
            if n is not node1
            and n.node_type not in ('input', 'bias')
            and (node1.node_id, n.node_id) not in existing
            and n.node_id not in ancestors
        ]
        if eligible_nodes:
            node2 = random.choice(eligible_nodes)
            self.add_connection(node1, node2, random.uniform(-1, 1))
        else:
            print("No eligible nodes available for connection mutation.")

    def forward(self, inputs):
        """Evaluate the network on *inputs* and return the output node's value.

        Args:
            inputs: Sequence with one value per input node, in node order.

        Returns:
            The value of the first 'output' node after activation. If the
            genome has no output node, the value of the highest-id node.
        """
        for i, node in enumerate([n for n in self.nodes.values() if n.node_type == 'input']):
            node.value = inputs[i]
        self._sync_node_inputs()
        for node in self._activation_order():
            node.activate()
        for node in self.nodes.values():
            if node.node_type == 'output':
                return node.value
        return self.nodes[max(self.nodes.keys())].value

    def _sync_node_inputs(self):
        """Rebuild every node's input list from the enabled connection genes."""
        # Without this, weight mutations and disabled connections would never
        # reach the evaluated network, because node.inputs stores a snapshot
        # of the weight taken when the connection was created.
        for node in self.nodes.values():
            node.inputs = []
        for conn in self.connections:
            if conn.enabled:
                conn.out_node.inputs.append((conn.in_node, conn.weight))

    def _activation_order(self):
        """Return the nodes sorted so that sources come before their targets."""
        indegree = {node_id: 0 for node_id in self.nodes}
        successors = {node_id: [] for node_id in self.nodes}
        for conn in self.connections:
            if not conn.enabled:
                continue
            src, dst = conn.in_node.node_id, conn.out_node.node_id
            if src in indegree and dst in indegree:
                successors[src].append(dst)
                indegree[dst] += 1
        ready = [node_id for node_id, degree in indegree.items() if degree == 0]
        order = []
        while ready:
            node_id = ready.pop()
            order.append(node_id)
            for succ in successors[node_id]:
                indegree[succ] -= 1
                if indegree[succ] == 0:
                    ready.append(succ)
        # Nodes caught in a cycle cannot be ordered; evaluate them last,
        # in insertion order.
        placed = set(order)
        order.extend(node_id for node_id in self.nodes if node_id not in placed)
        return [self.nodes[node_id] for node_id in order]

    def _ancestor_ids(self, node):
        """Return the ids of all nodes that have a directed path to *node*."""
        predecessors = {}
        for conn in self.connections:
            predecessors.setdefault(conn.out_node.node_id, []).append(conn.in_node.node_id)
        seen = set()
        stack = [node.node_id]
        while stack:
            for pred in predecessors.get(stack.pop(), ()):
                if pred not in seen:
                    seen.add(pred)
                    stack.append(pred)
        return seen
class Species:
    """A group of genomes gathered around a representative genome."""

    def __init__(self, representative):
        """Start an empty species anchored on *representative*."""
        self.representative = representative
        self.genomes = []
        self.best_fitness = 0.0
        self.generations_without_improvement = 0

    def add_genome(self, genome):
        """Add *genome* to this species' member list."""
        self.genomes.append(genome)

    def calculate_shared_fitness(self):
        """Divide each member's fitness, in place, by the member count."""
        member_count = len(self.genomes)
        for member in self.genomes:
            member.fitness = member.fitness / member_count

    def update_best_fitness(self):
        """Record a new best fitness or count another generation without one."""
        top_fitness = max(member.fitness for member in self.genomes)
        if not top_fitness > self.best_fitness:
            self.generations_without_improvement += 1
            return
        self.best_fitness = top_fitness
        self.generations_without_improvement = 0

    def is_extinct(self):
        """Return True once the stagnation limit has been reached."""
        return not self.generations_without_improvement < STAGNATION_LIMIT
class Population:
    """A fixed-size set of genomes evolved, with speciation, on the XOR task."""

    def __init__(self, pop_size):
        """Create *pop_size* minimal genomes."""
        self.pop_size = pop_size
        self.genomes = [self.create_initial_genome() for _ in range(pop_size)]
        # Species are kept between generations so stagnation can accumulate.
        self.species = []

    def create_initial_genome(self):
        """Return a genome with two inputs and a bias wired to one output."""
        genome = Genome()
        for node_id, node_type in enumerate(('input', 'input', 'bias', 'output')):
            genome.nodes[node_id] = NodeGene(node_id, node_type)
        for source_id in (0, 1, 2):
            genome.add_connection(genome.nodes[source_id], genome.nodes[3], random.uniform(-1, 1))
        return genome

    def speciate(self):
        """Assign every genome to the first species it is compatible with.

        Existing species objects are reused (their members are cleared and
        refilled) so their stagnation counters carry over between
        generations. Species left without members are dropped, and each
        remaining species takes its first member as its new representative.
        """
        species_list = list(getattr(self, 'species', []))
        for species in species_list:
            species.genomes = []
        for genome in self.genomes:
            for species in species_list:
                if self.is_compatible(genome, species.representative):
                    species.add_genome(genome)
                    break
            else:
                new_species = Species(genome)
                new_species.add_genome(genome)
                species_list.append(new_species)
        self.species = [species for species in species_list if species.genomes]
        for species in self.species:
            species.representative = species.genomes[0]

    def is_compatible(self, genome1, genome2):
        """Return True when the two genomes are closer than the threshold.

        The distance is C1 * excess + C2 * disjoint + C3 * mean weight
        difference of matching genes. Genes are matched by innovation
        number. A gene is "excess" when its innovation number is beyond the
        other genome's highest one, and "disjoint" when it is unmatched but
        within that range. Gene counts are not normalised by genome size.
        """
        genes1 = {conn.innovation: conn for conn in genome1.connections}
        genes2 = {conn.innovation: conn for conn in genome2.connections}
        matching = genes1.keys() & genes2.keys()
        unmatched = genes1.keys() ^ genes2.keys()
        # Everything above the smaller of the two maxima is excess. With an
        # empty genome the limit is -1, so all of the other's genes are excess.
        excess_limit = min(max(genes1, default=-1), max(genes2, default=-1))
        excess_genes = sum(1 for innovation in unmatched if innovation > excess_limit)
        disjoint_genes = len(unmatched) - excess_genes
        if matching:
            weight_diff = sum(abs(genes1[i].weight - genes2[i].weight) for i in matching)
            avg_weight_diff = weight_diff / len(matching)
        else:
            avg_weight_diff = 0.0
        # Compatibility distance formula
        distance = C1 * excess_genes + C2 * disjoint_genes + C3 * avg_weight_diff
        return distance < COMPATIBILITY_THRESHOLD

    def evaluate_fitness(self, genome):
        """Score *genome* on XOR, store the score on it and return it."""
        error = 0.0
        for inputs, target in XOR_INPUTS:
            output = genome.forward(inputs)
            error += (output - target) ** 2
        genome.fitness = 4 - error  # Minimize error, max fitness is 4
        return genome.fitness

    def evolve(self, generations, fitness_threshold=3.9):
        """Run the evolutionary loop and return a solving genome if one appears.

        Args:
            generations: Maximum number of generations to run.
            fitness_threshold: Raw fitness (at most 4) at which a genome
                counts as a solution.

        Returns:
            The first genome whose raw fitness reaches *fitness_threshold*,
            or None if none does within *generations*.
        """
        for i in range(generations):
            if not self.genomes:
                return None
            for genome in self.genomes:
                self.evaluate_fitness(genome)
            # Check for a solution on raw fitness, before fitness sharing
            # rescales the values.
            best = max(self.genomes, key=lambda g: g.fitness)
            if best.fitness >= fitness_threshold:
                return best
            self.speciate()
            # Track stagnation on raw fitness so a change in species size
            # does not register as progress or regress.
            for species in self.species:
                species.update_best_fitness()
            # Remove stagnant species, but never let the population die out.
            survivors = [species for species in self.species if not species.is_extinct()]
            if not survivors:
                survivors = [max(self.species, key=lambda s: s.best_fitness)]
            self.species = survivors
            print(f"Species count after extinction: {len(self.species)}")
            self.adjust_fitness_within_species()
            self.genomes = self._breed_next_generation(i)
        return None

    def _breed_next_generation(self, generation):
        """Return pop_size genomes: species champions plus mutated offspring."""
        next_gen = []
        breeding_weights = []
        for j, species in enumerate(self.species):
            species_best = max(species.genomes, key=lambda g: g.fitness)
            print(f"Gen: {generation} - Species: {j} Champion with {len(species_best.connections)} cons, {len(species_best.nodes)}, and fitness {species_best.fitness}")
            next_gen.append(copy.deepcopy(species_best))  # Preserve best from each species
            # Summed shared fitness decides a species' share of the offspring.
            breeding_weights.append(sum(g.fitness for g in species.genomes))
        del next_gen[self.pop_size:]
        if min(breeding_weights, default=0) < 0 or sum(breeding_weights) <= 0:
            breeding_weights = None  # unusable weights: pick species uniformly
        while len(next_gen) < self.pop_size:
            species = random.choices(self.species, weights=breeding_weights)[0]
            next_gen.append(self.mutate(random.choice(species.genomes)))
        return next_gen

    def adjust_fitness_within_species(self):
        """Apply fitness sharing inside every species."""
        for species in self.species:
            species.calculate_shared_fitness()

    def reproduce(self):
        """Replace the population with its fitter half plus mutated offspring.

        The top half by fitness (at least one genome) is copied unchanged.
        The remaining slots are filled with mutated copies of those
        survivors, taken in turn, so the population stays at pop_size.
        """
        if not self.genomes:
            return
        ranked = sorted(self.genomes, key=lambda g: g.fitness, reverse=True)
        survivors = ranked[:max(1, self.pop_size // 2)]
        next_gen = [copy.deepcopy(genome) for genome in survivors][:self.pop_size]
        parent_index = 0
        while len(next_gen) < self.pop_size:
            next_gen.append(self.mutate(survivors[parent_index % len(survivors)]))
            parent_index += 1
        self.genomes = next_gen

    def mutate(self, genome):
        """Return a mutated deep copy of *genome*; the original is untouched."""
        child = copy.deepcopy(genome)
        child.mutate()
        return child
class NEATCore:
    """Entry point that wires optional loader/trainer classes to a NEAT run."""

    def __init__(self, loader_class=None, trainer_class=None, params=None):
        """Store the configuration and build the loader if one is given.

        Args:
            loader_class: Optional callable invoked as ``loader_class(params)``.
                When None, no loader is created and ``loader`` is None.
            trainer_class: Optional trainer class, stored as-is.
            params: Optional mapping of settings. 'cull_rate' is read from
                it and defaults to 0.9.
        """
        self.name = 'NEATCore'
        # The declared defaults are None, so neither may be used unguarded.
        self.loader = loader_class(params) if loader_class is not None else None
        self.trainer_class = trainer_class
        self.params = params
        settings = {} if params is None else params
        self.cull_rate = settings.get('cull_rate', 0.9)
        self.nodes = []

    def run(self, pop_size=20, generations=100000):
        """Evolve a population on XOR, report the outcome and return it.

        Args:
            pop_size: Number of genomes in the population.
            generations: Maximum number of generations to evolve.

        Returns:
            Whatever ``Population.evolve`` returns: a solution genome, or a
            falsy value when no solution was found.
        """
        # Doing a super simple XOR first
        population = Population(pop_size=pop_size)
        print("Population:")
        for genome in population.genomes:
            print(f"Genome with {len(genome.connections)} cons, {len(genome.nodes)} nodes, and fitness {genome.fitness}")
        solution = population.evolve(generations=generations)
        if solution:
            print("Solution network weights:")
            for conn in solution.connections:
                print(f"From {conn.in_node.node_id} to {conn.out_node.node_id}, weight: {conn.weight}")
        else:
            print("No solution found within the specified generations.")
        return solution
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment