Last active: April 29, 2017 17:43
Save JonathanRaiman/bcd3066f73d29e633dda8440f9ac385a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from deap import algorithms, base, creator, tools | |
import numpy as np | |
# Synthetic entity-disambiguation benchmark.
#
# `classifications[e, d]` is True when entity `e` belongs to domain `d`.
# `mentions` is a list of (correct_entity, candidate_entities, candidate_counts).
domains = 100
num_entities = 10000
entity_num_domains = 5
num_mentions = 200
# Each entity joins each of the first `domains` domains independently with
# probability entity_num_domains / domains (about 5 domains per entity on average).
# `bool` replaces the removed `np.bool` alias (gone since NumPy 1.24).
classifications = np.random.binomial(
    1, np.ones(domains) * entity_num_domains / domains, size=(num_entities, domains)
).astype(bool)
# Widen the search space to 1000 domains. The 900 padded columns are all False,
# so no entity belongs to them: they are pure distractors for the optimizer.
domains = 1000
classifications = np.pad(
    classifications,
    [(0, 0), (0, domains - classifications.shape[1])],
    mode="constant",
)
# Each mention draws 2..19 candidates *with replacement* (so the correct entity
# can appear more than once). The first candidate is the correct one, and each
# candidate gets a random popularity count in [1, 99].
mentions = []
for i in range(num_mentions):
    num_ents = np.random.randint(2, 20)
    ents = np.random.choice(num_entities, size=num_ents)
    counts = np.random.randint(1, 100, size=len(ents))
    right_ent = ents[0]
    mentions.append((right_ent, ents, counts))
def rollout(individual, penalty=0.01, classification_table=None, mention_list=None):
    """Score a domain subset by how many mentions it disambiguates correctly.

    For each mention, the candidates are first filtered to those whose
    membership in the selected domains matches the correct entity's exactly.
    The mention counts as solved when only one candidate survives, or when the
    highest-count survivor is the correct entity. With no domains selected,
    this reduces to "the highest-count candidate is correct".

    Args:
        individual: Per-domain selection mask, one entry per column of the
            classification table (anything convertible to a bool array).
        penalty: Cost subtracted per selected domain.
        classification_table: (entities x domains) bool membership matrix.
            Defaults to the module-level ``classifications``.
        mention_list: Iterable of ``(right, others, counts)`` tuples, where
            ``others`` and ``counts`` are parallel numpy arrays. Defaults to
            the module-level ``mentions``.

    Returns:
        A 1-tuple ``(score,)``, the fitness format DEAP expects.
    """
    if classification_table is None:
        classification_table = classifications
    if mention_list is None:
        mention_list = mentions
    # `bool` replaces the removed `np.bool` alias (gone since NumPy 1.24).
    individual = np.asarray(individual, dtype=bool)
    subclassification = classification_table[:, individual]
    sample_sum = individual.sum()
    score = -sample_sum * penalty
    for right, others, counts in mention_list:
        if len(others) == 0:
            continue  # no candidates: nothing to credit (argmax would raise)
        if sample_sum == 0:
            if others[np.argmax(counts)] == right:
                score += 1
            continue
        right_signature = subclassification[right, :]
        indices = np.all(subclassification[others, :] == right_signature[None, :], axis=1)
        subset = others[indices]
        if subset.size == 0:
            continue  # right entity not among candidates: nothing matches it
        if len(subset) == 1 or subset[np.argmax(counts[indices])] == right:
            score += 1
    return (score,)
# Optimizer selection. `method` is never assigned in this script; the gist
# appears to have been run from a notebook where an earlier cell set it.
# Honor a pre-set value, but fall back to "ga" so the file also runs standalone.
method = globals().get("method", "ga")

if method == "ga":

    def _cx_two_point_copy(ind1, ind2):
        """Two-point crossover that actually swaps segments of numpy individuals.

        ``tools.cxTwoPoint`` swaps slices in place. With ``np.ndarray`` the
        slices are views, so the first assignment overwrites data the second one
        still reads: one parent gets the other's segment and the other parent
        is left unchanged. Copying the slices first makes it a real swap.
        """
        size = len(ind1)
        cxpoint1 = random.randint(1, size)
        cxpoint2 = random.randint(1, size - 1)
        if cxpoint2 >= cxpoint1:
            cxpoint2 += 1
        else:  # order the two cut points
            cxpoint1, cxpoint2 = cxpoint2, cxpoint1
        ind1[cxpoint1:cxpoint2], ind2[cxpoint1:cxpoint2] = (
            ind2[cxpoint1:cxpoint2].copy(),
            ind1[cxpoint1:cxpoint2].copy(),
        )
        return ind1, ind2

    creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    creator.create("Individual", np.ndarray, fitness=creator.FitnessMax)
    history = tools.History()
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("min", np.min)
    stats.register("max", np.max)
    toolbox = base.Toolbox()
    toolbox.register("attr_bool", random.randint, 0, 1)
    toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_bool, n=domains)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    toolbox.register("evaluate", rollout)
    toolbox.register("mate", _cx_two_point_copy)
    toolbox.register("mutate", tools.mutFlipBit, indpb=0.05)
    toolbox.decorate("mate", history.decorator)
    toolbox.decorate("mutate", history.decorator)
    toolbox.register("select", tools.selTournament, tournsize=5)
    # Seed 300 sparse individuals (about 10 selected domains each) rather than
    # using toolbox.population, which would start at ~50% density.
    pop = [
        creator.Individual(np.random.binomial(1, p=10 / domains, size=domains).astype(bool))
        for _ in range(300)
    ]
    # eaSimple replaces the contents of `pop` in place with the final generation.
    algorithms.eaSimple(pop, toolbox, cxpb=0.5, mutpb=0.2, ngen=20, verbose=True,
                        stats=stats)
    # Number of domains selected by the best final individual.
    print(sum(tools.selBest(pop, k=1)[0]))
elif method == "reinforce":
    import tensorflow as tf

    def solve_with_reinforce(rollout, iterations=20, batch_size=20, lr=0.01, alpha=0.99):
        """Optimize per-domain Bernoulli selection probabilities with REINFORCE.

        Each iteration samples `batch_size` domain masks, scores them with
        `rollout`, standardizes the rewards, and takes one gradient step on
        ``-sum(log p(mask) * reward)``. The probabilities are then clipped
        into (0, 1). Progress is printed each iteration.

        `alpha` is accepted for backward compatibility but is not used.
        """
        tf.reset_default_graph()
        session = tf.InteractiveSession()
        # Start with each domain's inclusion probability at 10 / domains.
        state = tf.get_variable("Probs",
                                initializer=tf.constant_initializer(10.0 / domains),
                                shape=[domains],
                                dtype=tf.float32)
        nsamples = tf.placeholder(tf.int32, [], name="nsamples")
        # NOTE(review): `p=`, `sample_n` and `log_pmf` are early TF 1.x contrib
        # APIs; later releases renamed them (`probs=`, `sample`, `log_prob`).
        # Confirm against the installed TensorFlow version.
        bernoulli = tf.contrib.distributions.Bernoulli(p=state)
        samples = bernoulli.sample_n(nsamples)
        prob_samples = tf.reduce_sum(bernoulli.log_pmf(samples), 1)
        rewards = tf.placeholder(tf.float32, [None], name="rewards")
        # Per-sample loss vector; minimize() differentiates its sum.
        loss = -(prob_samples * rewards)
        train_op = tf.train.GradientDescentOptimizer(learning_rate=lr).minimize(loss)
        # After each gradient step, clip the probabilities back into (0, 1).
        with tf.control_dependencies([train_op]):
            update_op = tf.assign(state, tf.clip_by_value(state, 1e-6, 1.0 - 1e-6))
        session.run(tf.global_variables_initializer())
        for i in range(iterations):
            batch_samples = session.run(samples, {nsamples: batch_size})
            batch_rewards = np.zeros(batch_size, dtype=np.float32)
            for row, sample in enumerate(batch_samples.astype(bool)):
                batch_rewards[row] = rollout(sample)[0]
            norm_rewards = batch_rewards - np.mean(batch_rewards)
            reward_std = np.std(batch_rewards)
            # Identical rewards give std == 0, and 0 / 0 would write NaN into
            # the probabilities. The centered advantages are already all zero then.
            if reward_std > 0:
                norm_rewards = norm_rewards / reward_std
            # Re-feed the drawn samples so the log-probs match the scored masks.
            session.run(update_op, {samples: batch_samples, rewards: norm_rewards})
            print("{}, max: {}, min: {}, mean: {}".format(
                i,
                np.max(batch_rewards),
                np.min(batch_rewards),
                np.mean(batch_rewards)), flush=True)

    solve_with_reinforce(rollout, batch_size=20, lr=0.01, iterations=50)
else:
    raise ValueError("unknown method {!r}; expected 'ga' or 'reinforce'".format(method))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.