Skip to content

Instantly share code, notes, and snippets.

@fedden
Created December 13, 2017 15:42
Show Gist options
  • Save fedden/d12b67c4ddd894aa00592efb7fec521f to your computer and use it in GitHub Desktop.
Save fedden/d12b67c4ddd894aa00592efb7fec521f to your computer and use it in GitHub Desktop.
def _next_neighbours(tensor):
    """Rotate *tensor* along axis 0 so row i holds original row (i + 1) % n."""
    return tf.concat([tensor[1:], tensor[:1]], axis=0)


def _previous_neighbours(tensor):
    """Rotate *tensor* along axis 0 so row i holds original row (i - 1) % n."""
    return tf.concat([tensor[-1:], tensor[:-1]], axis=0)


def _rmse_fitness(population, target_solution):
    """Return each row's root-mean-square distance to *target_solution* (lower is fitter)."""
    return tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(population, target_solution)), axis=1))


def vectorised_dfo(solution_size=5, population_size=1000000, iteration_count=100,
                   disturbance_threshold=0.1, lower=-300.0, upper=300.0,
                   print_results=False):
    """Run a vectorised Dispersive Flies Optimisation search with TensorFlow ops.

    The whole population is updated at once. Fitness is the root-mean-square
    distance of each fly to an all-zeros target vector (lower is better). On
    every iteration each fly:
      * takes the fitter of its two ring neighbours (rows i - 1 and i + 1,
        wrapping around);
      * moves from that neighbour a uniform U[0, 1) fraction of the way, per
        component, towards the swarm's best fly;
      * has each component independently replaced by a fresh random sample
        with probability ``disturbance_threshold``.

    Args:
        solution_size: Number of components per candidate solution.
        population_size: Number of flies (rows) in the population.
        iteration_count: Number of update steps to build.
        disturbance_threshold: Per-component probability of a random reset.
        lower: Lower end of the sampling interval.
        upper: Upper end of the sampling interval. Samples come from a normal
            distribution centred on the midpoint of ``[lower, upper]`` with a
            standard deviation of half its width. They are NOT clipped to the
            interval.
        print_results: If true, print the best fitness at each iteration.

    Returns:
        A rank-1 tensor of length ``solution_size``: the fittest fly of the
        final population (the initial population when ``iteration_count`` is
        0). Under TF1 graph mode, evaluate it with a session to get values.

    Raises:
        ValueError: If ``solution_size`` or ``population_size`` is below 1.
    """
    if solution_size < 1:
        raise ValueError("solution_size must be at least 1")
    if population_size < 1:
        raise ValueError("population_size must be at least 1")

    target_solution = tf.zeros((solution_size,))
    shape = [population_size, solution_size]
    # Midpoint and half-width of [lower, upper] (np.std uses ddof=0).
    mean = np.mean([lower, upper])
    std = np.std([lower, upper])
    population = tf.random_normal(shape=shape, mean=mean, stddev=std)

    for _ in range(iteration_count):
        fitnesses = _rmse_fitness(population, target_solution)
        swarms_best_index = tf.argmin(fitnesses)
        swarms_best = tf.gather(population, swarms_best_index)
        if print_results:
            # NOTE(review): under TF1 graph mode this prints the symbolic Tensor,
            # not its value; numbers appear only with eager execution enabled.
            print(tf.gather(fitnesses, swarms_best_index))

        # Ring topology: compare each fly's wrap-around neighbours. Slicing
        # (rather than splitting off first/middle/last) also works for a
        # single-fly population.
        population_up = _next_neighbours(population)
        population_down = _previous_neighbours(population)
        fitnesses_up = _next_neighbours(fitnesses)
        fitnesses_down = _previous_neighbours(fitnesses)
        # With TF1 tf.where, a rank-1 condition selects whole rows of x / y.
        best_neighbours = tf.where(fitnesses_up < fitnesses_down,
                                   x=population_up, y=population_down)

        # Random ops are created in the original order so seeded runs match.
        disturbance_rolls = tf.random_uniform(shape=shape, maxval=1.0)
        random_resets = tf.random_normal(shape=shape, mean=mean, stddev=std)
        move_amount = tf.random_uniform(shape=shape, maxval=1.0)
        fly_update = best_neighbours + move_amount * (swarms_best - best_neighbours)
        population = tf.where(disturbance_rolls < disturbance_threshold,
                              random_resets, fly_update)

    # Evaluate the final population so the last update is not wasted.
    final_fitnesses = _rmse_fitness(population, target_solution)
    return tf.gather(population, tf.argmin(final_fitnesses))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment