import random

Q_func = init_model()  # the Q-function neural net that we will train
memory = Memory()      # replay memory storing (state, action, reward, next_state) experiences

for episode in range(NR_EPISODES):
    # sample a new random graph (coordinates and distance matrix)
    coords, W = get_graph_mat(n=NR_NODES)

    # initialize a partial solution - start from a random node
    solution = [random.randint(0, NR_NODES - 1)]
    current_state = State(partial_solution=solution, W=W, coords=coords)

    # cooling schedule for epsilon, the exploration probability
    epsilon = max(MIN_EPSILON, (1 - EPSILON_DECAY_RATE) ** episode)

    while not is_state_final(current_state):
        # iterate until the TSP tour is complete
        if epsilon >= random.random():
            # explore: pick a random unvisited node with probability epsilon
            next_node = get_next_neighbor_random(current_state)
        else:
            # exploit: take the best (greedy) action according to our Q-function
            next_node, _ = Q_func.get_best_action(current_state)

        next_solution = solution + [next_node]

        # reward observed for taking this step: minus the added travel distance
        reward = -(total_distance(next_solution, W) - total_distance(solution, W))
        next_state = State(partial_solution=next_solution, W=W, coords=coords)

        # store the experience in replay memory (1-step Q-learning)
        memory.remember(Experience(state=current_state,
                                   action=next_node,
                                   reward=reward,
                                   next_state=next_state))

        # update state and current solution
        current_state = next_state
        solution = next_solution

        # sample a batch of states/actions/targets and take a gradient step
        experiences = memory.sample_batch(BATCH_SIZE)
        batch_states = [e.state for e in experiences]
        batch_actions = [e.action for e in experiences]
        batch_targets = []
        for experience in experiences:
            # target = observed reward + discounted (gamma = 0.9) estimate of the
            # best value achievable from the next state, as in 1-step Q-learning
            target = experience.reward
            _, est_reward = Q_func.get_best_action(experience.next_state)
            target += 0.9 * est_reward
            batch_targets.append(target)

        Q_func.batch_update(batch_states, batch_actions, batch_targets)
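
The loop above assumes a few helpers and hyperparameters that this gist does not define: State, Experience, Memory, get_graph_mat, total_distance, is_state_final, get_next_neighbor_random, and the NR_NODES / NR_EPISODES / MIN_EPSILON / EPSILON_DECAY_RATE / BATCH_SIZE constants. The sketch below shows one minimal way they could look, consistent with how they are called above; it is an illustration, not the author's implementation, and the hyperparameter values are placeholders.

import random
from collections import deque, namedtuple

import numpy as np

# placeholder hyperparameters (illustrative values only, not the ones used in the original experiments)
NR_NODES = 10               # graph size
NR_EPISODES = 2000          # number of training episodes
BATCH_SIZE = 16             # replay batch size
MIN_EPSILON = 0.1           # floor for the exploration probability
EPSILON_DECAY_RATE = 0.005  # per-episode decay of epsilon

# A state is the partial tour plus the graph it lives on; an experience is a standard 1-step transition.
State = namedtuple('State', ['partial_solution', 'W', 'coords'])
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state'])


def get_graph_mat(n=10):
    """Sample n points uniformly in the unit square; return (coords, distance matrix)."""
    coords = np.random.uniform(size=(n, 2))
    W = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    return coords, W


def total_distance(solution, W):
    """Length of the path visiting `solution` in order (closing the tour once all nodes are in)."""
    if len(solution) < 2:
        return 0.0
    dist = sum(W[solution[i], solution[i + 1]] for i in range(len(solution) - 1))
    if len(solution) == W.shape[0]:
        dist += W[solution[-1], solution[0]]  # return to the starting node
    return dist


def is_state_final(state):
    """The tour is complete once every node appears in the partial solution."""
    return len(set(state.partial_solution)) == state.W.shape[0]


def get_next_neighbor_random(state):
    """Pick a random not-yet-visited node (the exploration move)."""
    candidates = [n for n in range(state.W.shape[0]) if n not in state.partial_solution]
    return random.choice(candidates)


class Memory:
    """Fixed-capacity replay buffer of Experience tuples."""
    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)

    def remember(self, experience):
        self.buffer.append(experience)

    def sample_batch(self, batch_size):
        # sample without replacement; cap at the current buffer size early in training
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))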
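
The remaining pieces are init_model() and the Q_func object, which must expose get_best_action(state) -> (node, estimated value) and batch_update(states, actions, targets). The actual model is not shown in this gist; purely to make that interface concrete, the toy stand-in below (assuming PyTorch is available) scores each candidate node with a small MLP over a few hand-crafted features. It is a placeholder sketch, not the model used in the original work.

import torch
import torch.nn as nn


class QFunction:
    """Toy Q-function exposing the interface used by the training loop above.
    An MLP over hand-crafted per-node features, shown only as an illustrative stand-in."""

    def __init__(self, lr=5e-4):
        self.net = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 1))
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()

    def _features(self, state, node):
        # simple features for appending `node` to the partial tour:
        # distance from the last visited node, distance back to the start,
        # and the fraction of the tour built so far
        last, first = state.partial_solution[-1], state.partial_solution[0]
        n = state.W.shape[0]
        return torch.tensor([state.W[last, node],
                             state.W[node, first],
                             len(state.partial_solution) / n],
                            dtype=torch.float32)

    def _value(self, state, node):
        return self.net(self._features(state, node)).squeeze()

    def get_best_action(self, state):
        """Return (best next node, its estimated value); (None, 0.0) on final states."""
        candidates = [n for n in range(state.W.shape[0])
                      if n not in state.partial_solution]
        if not candidates:
            return None, 0.0
        with torch.no_grad():
            values = [self._value(state, n).item() for n in candidates]
        best = max(range(len(candidates)), key=lambda i: values[i])
        return candidates[best], values[best]

    def batch_update(self, states, actions, targets):
        """One gradient step on the squared error between Q(s, a) and the targets."""
        preds = torch.stack([self._value(s, a) for s, a in zip(states, actions)])
        loss = self.loss_fn(preds, torch.tensor(targets, dtype=torch.float32))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


def init_model():
    return QFunction()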