@breeko
Last active February 18, 2018 20:00
A portion of a larger program that uses an experience replay object to train a Q-network.
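The experience replay object itself is defined elsewhere in the program. For reference, a minimal sketch of a buffer compatible with the `sample(batch_size)` call below might look like this (the class name, capacity, and uniform sampling are assumptions, not the gist's actual implementation):

import random
from collections import deque

import numpy as np

class ExperienceReplay:
    # Hypothetical ring buffer of (state, action, reward, next_state, done) tuples
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)  # oldest transitions drop off first

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform random sample returned as a (batch_size, 5) object array,
        # so that `train_batch.T` unpacks into five per-column arrays
        batch = random.sample(self.buffer, batch_size)
        return np.array(batch, dtype=object)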
import numpy as np

if num_episode % update_freq == 0:
    for num_epoch in range(num_epochs):
        # Train batch is [[state, action, reward, next_state, done], ...]
        train_batch = experience_replay.sample(batch_size)
        # Separate the batch into its components
        train_state, train_action, train_reward, \
            train_next_state, train_done = train_batch.T
        # Convert the action array into an array of ints so it can be used for indexing
        train_action = train_action.astype(int)
        # Stack the train_state and train_next_state arrays for learning
        train_state = np.vstack(train_state)
        train_next_state = np.vstack(train_next_state)
        # Our predictions (actions to take) from the main Q network
        target_q = main_qn.model.predict(train_state)
        # The Q values of the next state from our target network
        target_q_next_state = target_qn.model.predict(train_next_state)
        # The action with the highest Q value in each next state
        train_next_state_action = np.argmax(target_q_next_state, axis=1)
        # Mask that is 0 for terminal transitions and 1 otherwise;
        # multiplying the discounted next-state values by it ensures
        # we don't bootstrap past the end of an episode
        train_gameover = train_done == 0
        # Q value of the next state based on the chosen action
        train_next_state_values = target_q_next_state[range(batch_size), train_next_state_action]
        # Bellman target: reward plus the discounted value of the next state
        actual_reward = train_reward + (y * train_next_state_values * train_gameover)
        target_q[range(batch_size), train_action] = actual_reward
        # Train the main model on the updated targets
        loss = main_qn.model.train_on_batch(train_state, target_q)
        losses.append(loss)
        # Update the target model with values from the main model
        update_target_graph(main_qn.model, target_qn.model, tau)
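`update_target_graph` is also defined outside this excerpt. Given the `tau` parameter, it is presumably a soft update that blends the target network's weights toward the main network's; a hedged sketch for Keras models (an assumption about the original helper's behavior, not its actual code) could be:

def update_target_graph(main_model, target_model, tau):
    # Hypothetical soft update: move each target weight a fraction tau
    # toward the corresponding main weight (tau=1.0 would copy outright)
    main_weights = main_model.get_weights()
    target_weights = target_model.get_weights()
    blended = [tau * mw + (1.0 - tau) * tw
               for mw, tw in zip(main_weights, target_weights)]
    target_model.set_weights(blended)

A small tau (e.g. 0.001) keeps the target network changing slowly, which stabilizes the bootstrapped Q targets the main network is trained against.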