@rayheberer
Created August 17, 2018 00:25
import tensorflow as tf


class AtariNet(object):
    # ...
    # ...
    def _build(self):
        # ...
        # ...
        # convolutional layers for minimap features
        self.minimap_conv1 = tf.layers.conv2d(
            inputs=self.minimap_processed,
            filters=16,
            kernel_size=[8, 8],
            strides=[4, 4],
            padding="SAME",
            name="minimap_conv1")
        self.minimap_activation1 = tf.nn.relu(
            self.minimap_conv1,
            name="minimap_activation1")
        self.minimap_conv2 = tf.layers.conv2d(
            inputs=self.minimap_activation1,
            filters=32,
            kernel_size=[4, 4],
            strides=[2, 2],
            padding="SAME",
            name="minimap_conv2")
        self.minimap_activation2 = tf.nn.relu(
            self.minimap_conv2,
            name="minimap_activation2")
        # linear layer for non-spatial features (tanh activation)
        self.flat_linear = tf.layers.dense(
            inputs=self.flat_processed,
            units=64,
            activation=tf.tanh,
            name="flat_linear")
        # flatten and concatenate
        self.screen_flat = tf.layers.flatten(
            self.screen_activation2,
            name="screen_flat")
        self.minimap_flat = tf.layers.flatten(
            self.minimap_activation2,
            name="minimap_flat")
        self.concat = tf.concat(
            values=[self.screen_flat, self.minimap_flat, self.flat_linear],
            axis=1,
            name="concat")
        # linear layer with ReLU activation
        self.state_representation = tf.layers.dense(
            inputs=self.concat,
            units=256,
            activation=tf.nn.relu,
            name="state_representation")
        # ...
        # ...
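As a rough sanity check of the shape arithmetic above (not part of the gist; the 64x64 resolution and the 7-channel placeholder are assumptions for illustration), the two minimap convolutions can be reproduced on a dummy tensor:

import tensorflow as tf

# hypothetical minimap input: batch x 64 x 64 x 7
minimap = tf.placeholder(tf.float32, [None, 64, 64, 7])
conv1 = tf.nn.relu(tf.layers.conv2d(
    minimap, filters=16, kernel_size=[8, 8], strides=[4, 4], padding="SAME"))
conv2 = tf.nn.relu(tf.layers.conv2d(
    conv1, filters=32, kernel_size=[4, 4], strides=[2, 2], padding="SAME"))
print(conv2.shape)                      # (?, 8, 8, 32)
print(tf.layers.flatten(conv2).shape)   # (?, 2048)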