A portion of a larger program that uses an experience replay object to train a Q-network.
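The training loop below calls experience_replay.sample(batch_size) and expects a (batch_size, 5) array of [state, action, reward, next_state, done] rows that it can transpose and unpack. The replay object itself is not part of this gist; the following is a minimal sketch of one way it could be written, assuming a fixed-capacity buffer with uniform random sampling (the class name ExperienceReplay and the buffer_size parameter are illustrative, not taken from the original program).

import numpy as np

class ExperienceReplay:
    """Fixed-size store of (state, action, reward, next_state, done) transitions."""

    def __init__(self, buffer_size=10000):
        self.buffer = []
        self.buffer_size = buffer_size

    def add(self, state, action, reward, next_state, done):
        # Drop the oldest transition once the buffer is full
        if len(self.buffer) >= self.buffer_size:
            self.buffer.pop(0)
        self.buffer.append([state, action, reward, next_state, done])

    def sample(self, batch_size):
        # Uniform random sample, returned as an object array so that .T
        # unpacks cleanly into the five component arrays used in the loop
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        return np.array([self.buffer[i] for i in indices], dtype=object)

Whatever the real implementation looks like, the only contract the training loop relies on is that sample(batch_size) returns rows in the [state, action, reward, next_state, done] order shown in the first comment below.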
if num_episode % update_freq == 0:
    for num_epoch in range(num_epochs):
        # Train batch is [[state, action, reward, next_state, done], ...]
        train_batch = experience_replay.sample(batch_size)

        # Separate the batch into its components
        train_state, train_action, train_reward, \
            train_next_state, train_done = train_batch.T

        # Convert the action array into an array of ints so it can be used for indexing
        train_action = train_action.astype(int)

        # Stack the train_state and train_next_state arrays for learning
        train_state = np.vstack(train_state)
        train_next_state = np.vstack(train_next_state)

        # Q values for the current states from the main network; these become
        # the training targets once the taken actions are overwritten below
        target_q = main_qn.model.predict(train_state)

        # Q values for the next states from the target network
        target_q_next_state = target_qn.model.predict(train_next_state)

        # Greedy action in each next state according to the target network
        train_next_state_action = np.argmax(target_q_next_state, axis=1)
        train_next_state_action = train_next_state_action.astype(int)

        # 0 where the episode ended, 1 otherwise; multiplying the bootstrapped
        # value by this ensures we don't train on a value beyond the last move
        train_gameover = train_done == 0

        # Q value of the next state for the chosen action
        train_next_state_values = target_q_next_state[range(batch_size), train_next_state_action]

        # Bellman target: immediate reward plus the discounted next-state value
        # (y is the discount factor, defined elsewhere in the program)
        actual_reward = train_reward + (y * train_next_state_values * train_gameover)
        target_q[range(batch_size), train_action] = actual_reward

        # Train the main model on the updated targets
        loss = main_qn.model.train_on_batch(train_state, target_q)
        losses.append(loss)

        # Update the target model with values from the main model
        update_target_graph(main_qn.model, target_qn.model, tau)
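update_target_graph is also defined elsewhere in the larger program. Given that it takes a tau argument, it is presumably a soft ("Polyak") update that blends the main network's weights into the target network's; a minimal Keras-style sketch under that assumption might look like the following (the body is illustrative, not the gist's actual implementation).

def update_target_graph(main_model, target_model, tau):
    # Move the target weights a fraction tau toward the main weights
    main_weights = main_model.get_weights()
    target_weights = target_model.get_weights()
    new_weights = [tau * mw + (1.0 - tau) * tw
                   for mw, tw in zip(main_weights, target_weights)]
    target_model.set_weights(new_weights)

With tau = 1.0 this reduces to copying the main network into the target network outright, which is the other common convention for DQN target updates.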