@rayheberer
Last active August 18, 2018 18:11
import tensorflow as tf


class AtariNet(object):
    # ...
    # ...
    def _build_optimization(self):
        # ...
        # ...
        # advantage = returns - baseline; stop_gradient keeps the policy loss
        # from backpropagating into the value estimate through the advantage
        self.advantage = tf.subtract(
            self.returns,
            tf.squeeze(tf.stop_gradient(self.value_estimate)),
            name="advantage")

        # a2c gradient = policy gradient + value gradient + regularization
        self.policy_gradient = -tf.reduce_mean(
            (self.advantage *
             tf.log(self.action_probability * self.args_probability)),
            name="policy_gradient")

        self.value_gradient = -tf.reduce_mean(
            self.advantage * tf.squeeze(self.value_estimate),
            name="value_gradient")

        # only including function identifier entropy, not args
        self.entropy = tf.reduce_sum(
            self.function_policy * tf.log(self.function_policy),
            name="entropy_regularization")

        self.a2c_gradient = tf.add_n(
            inputs=[self.policy_gradient,
                    self.value_gradient_strength * self.value_gradient,
                    self.regularization_strength * self.entropy],
            name="a2c_gradient")

        self.optimizer = tf.train.RMSPropOptimizer(
            self.learning_rate).minimize(
                self.a2c_gradient,
                global_step=self.global_step)
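
The attributes used above (returns, value_estimate, action_probability, and so on) are built in the elided parts of the class. As a minimal standalone sketch, assuming TensorFlow 1.x graph mode and toy tensors chosen purely for illustration, the same advantage-weighted policy loss construction can be exercised like this:

# Illustrative sketch only: the constants below stand in for the class
# attributes that the gist builds elsewhere (discounted returns, the critic's
# value estimate, and the probability of the action actually taken).
import tensorflow as tf

returns = tf.constant([1.0, 0.5])                # discounted returns, shape (batch,)
value_estimate = tf.constant([[0.8], [0.6]])     # critic output, shape (batch, 1)
action_probability = tf.constant([0.7, 0.2])     # probability of the taken action

# stop_gradient prevents the policy loss from updating the value head via the advantage
advantage = tf.subtract(returns, tf.squeeze(tf.stop_gradient(value_estimate)))
policy_loss = -tf.reduce_mean(advantage * tf.log(action_probability))

with tf.Session() as sess:
    print(sess.run([advantage, policy_loss]))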