@rayheberer
Created August 17, 2018 00:29
class A2CAtari(base_agent.BaseAgent):
    # ...
    # ...
    def _get_batch(self, terminal):
        # ...
        # ...
        # calculate discounted rewards
        raw_rewards = list(self.reward_buffer)
        if terminal:
            # episode ended: no future value to bootstrap from
            value = 0
        else:
            # bootstrap from the critic's value estimate of the last observed state
            value = np.squeeze(self.sess.run(
                self.network.value_estimate,
                feed_dict={self.network.screen_features: screen[-1:],
                           self.network.minimap_features: minimap[-1:],
                           self.network.flat_features: flat[-1:]}))

        returns = []
        # n-step discounted rewards for 1 <= n <= trajectory_training_steps
        for reward in raw_rewards:
            value = reward + self.discount_factor * value
            returns.append(value)
        # ...
        # ...
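The return computation above can be sketched in isolation. Note that the gist's loop walks the rewards in the order they sit in `self.reward_buffer`, which yields correct returns only if that buffer stores rewards newest-first; the standalone sketch below instead assumes the conventional oldest-first ordering and iterates backwards explicitly. The function name `n_step_returns` is hypothetical, not part of the gist.

    import numpy as np

    def n_step_returns(rewards, bootstrap_value, discount_factor):
        # `rewards` is assumed ordered oldest-first; the result keeps that order.
        # `bootstrap_value` is 0 for a terminal state, otherwise the critic's
        # value estimate of the state following the last reward.
        value = bootstrap_value
        returns = []
        # iterate backwards so each return accumulates all future rewards
        for reward in reversed(rewards):
            value = reward + discount_factor * value
            returns.append(value)
        return np.array(returns[::-1])

    # terminal trajectory: rewards [1, 0, 2], gamma = 0.9, bootstrap 0
    # -> [1 + 0.9*(0 + 0.9*2), 0 + 0.9*2, 2] = [2.62, 1.8, 2.0]
    print(n_step_returns([1, 0, 2], 0.0, 0.9))

Bootstrapping with the value estimate (instead of 0) is what lets the agent train on truncated, non-terminal trajectories.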