Skip to content

Instantly share code, notes, and snippets.

@fedden
Created November 27, 2017 15:29
Show Gist options
  • Save fedden/f7586046581557359a02aa20fc5ce39a to your computer and use it in GitHub Desktop.
Save fedden/f7586046581557359a02aa20fc5ce39a to your computer and use it in GitHub Desktop.
# Imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc
# A way of plotting each step of the optimisation process.
def plot(surface, positions, best_fitness, best_position, iteration):
plt.cla()
# Plot the feature / error surface.
plt.pcolormesh(surface[0], surface[1], surface[2])
# Plot all of the genepool.
x, y = zip(*positions)
plt.scatter(x, y, 1, 'k', edgecolors='face')
# Plot the best gene.
plt.scatter(best_position[0], best_position[1], 20, 'w', edgecolors='face')
# Add the title to the plot.
title = "iteration {}, fitness {}".format(iteration, best_fitness)
plt.title(title)
# Each time this function is ran, the optimisation process is furthered a step.
def run(iteration):
# Get the rastrigin error surface.
surface = function.get_surface()
# Get the gene pool.
positions = ga.get_solutions()
# Evaluate the fitnesses on the rastrigin surface
fitnesses = [function.evaluate(pos) for pos in positions]
# Inform the GA of the genepool's performance
ga.set_fitnesses(fitnesses)
# Get the best gene
best_position, best_fitness = ga.get_best()
# Plot the optimsation
plot(surface, positions, best_fitness, best_position, iteration)
# Genetic Algorithm parameters
elitism = 0.1
population_size = 500
mutation_rate = 0.8
mutation_sigma = 0.1
mutation_decay = 0.999
mutation_limit = 0.01
amount_optimisation_steps = 250
dna_bounds = (-5.11, 5.11)
dna_start_position = [4.8, 4.8]
# Construct the test function
function = Rastrigin()
# Construct the GA
ga = GA(len(dna_start_position),
dna_bounds,
dna_start_position,
elitism,
population_size,
mutation_rate,
mutation_sigma,
mutation_decay,
mutation_limit)
# Create the matplotlib figure
fig = plt.figure(figsize=(10, 10))
# See this on making gifs with matplotlib:
# https://matplotlib.org/api/animation_api.html
anim = animation.FuncAnimation(fig, run, frames=amount_optimisation_steps)
anim.save('rastrigin3_ga.mp4', fps=15)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment