Skip to content

Instantly share code, notes, and snippets.

@rayheberer
Created August 17, 2018 00:27
Show Gist options
  • Save rayheberer/88108b64fb783954a395d24d49b35405 to your computer and use it in GitHub Desktop.
Save rayheberer/88108b64fb783954a395d24d49b35405 to your computer and use it in GitHub Desktop.
class AtariNet(object):
    """Network whose ``_build`` step defines the action policy heads.

    Only the policy-head portion of the class is shown; elided sections
    are marked with ``# ...``.
    """
    # ...
    # ...
    def _build(self):
        """Create the function-id policy and one policy head per argument.

        Populates:
            self.function_policy: softmax dense layer over
                ``self.state_representation`` with ``NUM_ACTIONS`` units,
                passed through ``tf.squeeze``.
            self.argument_policy: dict of softmax dense layers keyed by
                ``str(arg_type)``; argument types with more than one size
                entry get two entries, suffixed ``"x"`` and ``"y"``.
            self.arguments: dict of float32 placeholders with the same
                keys and widths as ``self.argument_policy``.

        Raises:
            ValueError: if an argument type with more than one size entry
                belongs to neither ``SCREEN_TYPES`` nor ``MINIMAP_TYPES``.
        """
        # ...
        # ...
        # action function identifier policy
        # NOTE(review): squeeze is called without an axis, so every size-1
        # dimension is dropped -- confirm this is intended for batched input.
        self.function_policy = tf.squeeze(
            tf.layers.dense(
                inputs=self.state_representation,
                units=NUM_ACTIONS,
                activation=tf.nn.softmax),
            name="function_policy")

        # action function argument policies (nonspatial)
        # action function argument placeholders (for optimization)
        self.argument_policy = dict()
        self.arguments = dict()
        for arg_type in actions.TYPES:
            key = str(arg_type)

            # for spatial actions, represent each dimension independently
            if len(arg_type.sizes) > 1:
                if arg_type in SCREEN_TYPES:
                    units = self.screen_dimensions
                elif arg_type in MINIMAP_TYPES:
                    units = self.minimap_dimensions
                else:
                    # Without this branch `units` would be unbound on the
                    # first pass or silently reuse the previous argument
                    # type's dimensions on later passes.
                    raise ValueError(
                        "spatial argument type %r is in neither "
                        "SCREEN_TYPES nor MINIMAP_TYPES" % (arg_type,))

                # Layers are created before placeholders, x before y, so
                # auto-generated op/variable names keep their ordering.
                arg_policy_x = tf.layers.dense(
                    inputs=self.state_representation,
                    units=units[0],
                    activation=tf.nn.softmax)
                arg_policy_y = tf.layers.dense(
                    inputs=self.state_representation,
                    units=units[1],
                    activation=tf.nn.softmax)
                self.argument_policy[key + "x"] = arg_policy_x
                self.argument_policy[key + "y"] = arg_policy_y

                arg_placeholder_x = tf.placeholder(
                    tf.float32,
                    shape=[None, units[0]])
                arg_placeholder_y = tf.placeholder(
                    tf.float32,
                    shape=[None, units[1]])
                self.arguments[key + "x"] = arg_placeholder_x
                self.arguments[key + "y"] = arg_placeholder_y
            else:
                # Single-size argument: one head as wide as its only size.
                arg_policy = tf.layers.dense(
                    inputs=self.state_representation,
                    units=arg_type.sizes[0],
                    activation=tf.nn.softmax)
                self.argument_policy[key] = arg_policy

                arg_placeholder = tf.placeholder(
                    tf.float32,
                    shape=[None, arg_type.sizes[0]])
                self.arguments[key] = arg_placeholder
        # ...
        # ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment