1_Cartpole_DQN_replay_fucntion.py
def replay(self):
    # Only start training once the replay buffer has enough samples
    if len(self.memory) < self.train_start:
        return
    # Randomly sample a minibatch from the memory
    minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size))

    state = np.zeros((self.batch_size, self.state_size))
    next_state = np.zeros((self.batch_size, self.state_size))
    action, reward, done = [], [], []

    # Unpack the minibatch before prediction.
    # For speed this could be done at the tensor level,
    # but a loop is easier to understand.
    for i in range(self.batch_size):
        state[i] = minibatch[i][0]
        action.append(minibatch[i][1])
        reward.append(minibatch[i][2])
        next_state[i] = minibatch[i][3]
        done.append(minibatch[i][4])

    # Predict in batches instead of per sample to save time
    target = self.model.predict(state)
    target_next = self.model.predict(next_state)

    for i in range(self.batch_size):
        # Correct the Q value only for the action that was actually taken
        if done[i]:
            target[i][action[i]] = reward[i]
        else:
            # Standard DQN:
            # take the max Q value among next actions; selection and
            # evaluation of the action both use the same (online) network here
            # Q_max = max_a' Q(s', a')
            target[i][action[i]] = reward[i] + self.gamma * np.amax(target_next[i])

    # Train the neural network on the batch
    self.model.fit(state, target, batch_size=self.batch_size, verbose=0)
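Below is a minimal sketch of how replay() is typically wired into the rest of the agent, assuming the attribute names used above (self.memory, self.gamma, self.batch_size, self.train_start, self.state_size, self.model). The class name DQNAgentSketch, the hyperparameter values, and the small Keras model are illustrative assumptions, not part of the gist. The key point is that remember() must append transitions in the (state, action, reward, next_state, done) order that replay() indexes.

# Assumed imports; the gist relies on equivalents at module level.
from collections import deque
import random
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

class DQNAgentSketch:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)   # replay buffer capacity (assumed)
        self.gamma = 0.95                  # discount factor (assumed)
        self.batch_size = 64               # assumed
        self.train_start = 1000            # learning starts after this many samples (assumed)
        # Small fully connected Q-network: state -> one Q value per action (assumed architecture)
        self.model = Sequential([
            Dense(24, activation="relu", input_shape=(state_size,)),
            Dense(24, activation="relu"),
            Dense(action_size, activation="linear"),
        ])
        self.model.compile(loss="mse", optimizer="adam")

    def remember(self, state, action, reward, next_state, done):
        # Order must match the indexing used inside replay()
        self.memory.append((state, action, reward, next_state, done))

    # replay() as shown above would be defined here.

In a training loop, each environment step produces (state, action, reward, next_state, done); remember() stores the transition and replay() is called once per step, so learning only begins once the buffer holds train_start transitions.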