Skip to content

Instantly share code, notes, and snippets.

@pythonlessons
Created November 26, 2019 14:51
Show Gist options
  • Save pythonlessons/d78ff2a4d098ca0d15a6f3b1759d7c87 to your computer and use it in GitHub Desktop.
Save pythonlessons/d78ff2a4d098ca0d15a6f3b1759d7c87 to your computer and use it in GitHub Desktop.
1_Cartpole_DQN_replay_fucntion.py
def replay(self):
    """Train ``self.model`` on one minibatch sampled from replay memory.

    Does nothing until ``self.memory`` holds at least ``self.train_start``
    transitions. Each memory entry is read as
    ``(state, action, reward, next_state, done)`` by position.

    For every sampled transition, the Q-value of the action that was taken
    is replaced by the one-step Bellman target:

    * ``reward`` if the transition was terminal (``done``), otherwise
    * ``reward + gamma * max_a' Q(next_state, a')``.

    The other actions keep the model's own prediction, so they contribute
    zero error. The model is then fit towards these targets in a single
    call. Both the current and the next-state Q-values come from
    ``self.model``; this function uses no separate target network.
    """
    if len(self.memory) < self.train_start:
        return

    # Randomly sample a minibatch. It is smaller than self.batch_size when
    # memory holds fewer transitions, so size every buffer from the sample
    # itself rather than from self.batch_size (which would index past the end).
    minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size))
    batch_size = len(minibatch)
    if batch_size == 0:
        # Nothing to learn from (e.g. empty memory with train_start <= 0).
        return

    state = np.zeros((batch_size, self.state_size))
    next_state = np.zeros((batch_size, self.state_size))
    action, reward, done = [], [], []

    # Unpack the transitions before prediction. This could be done on the
    # tensor level for speed, but a loop is easier to follow.
    for i, transition in enumerate(minibatch):
        state[i] = transition[0]
        action.append(transition[1])
        reward.append(transition[2])
        next_state[i] = transition[3]
        done.append(transition[4])

    # One batched predict per array is much cheaper than per-sample calls.
    target = self.model.predict(state)
    target_next = self.model.predict(next_state)

    for i in range(batch_size):
        # Correct only the Q-value of the action actually taken.
        if done[i]:
            # Terminal transition: there is no future return to bootstrap.
            target[i][action[i]] = reward[i]
        else:
            # Standard DQN target: r + gamma * max_a' Q(s', a'),
            # with Q(s', .) evaluated by self.model (no target network here).
            target[i][action[i]] = reward[i] + self.gamma * np.amax(target_next[i])

    # Fit the network on the whole corrected batch.
    self.model.fit(state, target, batch_size=batch_size, verbose=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment