Created May 26, 2016 17:42
Save domluna/656182c1dc0dddf360d83da24eb7748d to your computer and use it in GitHub Desktop.
Reference example of using tf.get_variable
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def conv2d(inputs, filter, strides, name='conv2d', padding='SAME'):
    """Apply a 2-D convolution plus bias, followed by a ReLU.

    The kernel and bias are created with ``tf.get_variable`` under the
    names 'W' and 'b' in the *current* variable scope, so each call should
    be wrapped in its own ``tf.variable_scope`` (as ``vision_model`` does).

    Args:
        inputs: Input tensor passed to ``tf.nn.conv2d``.
        filter: Kernel shape passed to ``tf.get_variable``, e.g.
            ``[8, 8, n_frames, 32]``; its last entry also sizes the bias.
        strides: Stride list passed to ``tf.nn.conv2d``.
        name: Name given to the output ReLU op.
        padding: Padding mode passed to ``tf.nn.conv2d``. Defaults to
            ``'SAME'``, the value that used to be hard-coded.

    Returns:
        The ReLU-activated convolution output.
    """
    # NOTE: `filter` shadows the builtin; kept so keyword callers keep working.
    k = tf.get_variable('W', filter, initializer=xavier_initializer_conv2d())
    # Bias shape given as a 1-element list, consistent with the dense layers.
    b = tf.get_variable('b', [filter[-1]], initializer=tf.constant_initializer(0.0))
    conv = tf.nn.conv2d(inputs, k, strides, padding)
    bias_add = tf.nn.bias_add(conv, b)
    return tf.nn.relu(bias_add, name=name)
def vision_model(frames, n_frames):
    """Stack the three convolutional layers of the Atari network.

    Each layer is built inside its own variable scope ('Conv1', 'Conv2',
    'Conv3'), so its 'W'/'b' variables get distinct names. The scope's name
    is also used as the name of that layer's output op.

    Args:
        frames: Input tensor. The first kernel expects ``n_frames`` input
            channels.
        n_frames: Number of stacked frames (input channels of layer 1).

    Returns:
        The output tensor of the third convolutional layer.
    """
    # (scope name, kernel shape, strides) for each layer, applied in order.
    layer_specs = (
        ('Conv1', [8, 8, n_frames, 32], [1, 4, 4, 1]),
        ('Conv2', [4, 4, 32, 64], [1, 2, 2, 1]),
        ('Conv3', [3, 3, 64, 64], [1, 1, 1, 1]),
    )
    net = frames
    for scope_name, kernel_shape, stride in layer_specs:
        with tf.variable_scope(scope_name) as scope:
            net = conv2d(net, kernel_shape, stride, scope.name)
    return net
def atari_cnn(inputs, batch_size, n_actions, n_frames, name='atari_cnn'):
    """Build the Atari convolutional network and return its action outputs.

    Architecture:
        input 84 x 84 x 4 (image is 84x84, grayscaled)
        32 filters 8 x 8 stride 4, relu
        64 filters 4 x 4 stride 2, relu
        64 filters 3 x 3 stride 1, relu
        fc 512 units, relu
        fc 4-18 units (actions)

    Args:
        inputs: Stacked input frames, fed to ``vision_model``.
        batch_size: Leading dimension used when flattening the conv output.
            Pass ``None`` to keep the batch dimension dynamic.
        n_actions: Number of output units (one per action).
        n_frames: Number of stacked frames (input channels).
        name: Variable scope wrapping all of the network's variables.

    Returns:
        The linear (no activation) output of the 'Dense2' layer, with
        ``n_actions`` columns.

    Raises:
        ValueError: If the flattened size of the conv output is not known
            statically. It is needed to shape the first dense weight matrix.
    """
    with tf.variable_scope(name):
        conv3 = vision_model(inputs, n_frames)
        # Derive the flattened feature size from the static shape rather than
        # from `batch_size`, so the batch dimension may be left dynamic.
        dim = conv3.get_shape()[1:].num_elements()
        if dim is None:
            raise ValueError(
                'atari_cnn needs a statically known conv output shape, got %s'
                % conv3.get_shape())
        leading = -1 if batch_size is None else batch_size
        reshaped = tf.reshape(conv3, [leading, dim])
        with tf.variable_scope('Dense1') as scope:
            w = tf.get_variable('W', [dim, 512], initializer=xavier_initializer())
            b = tf.get_variable('b', [512], initializer=tf.constant_initializer(0.0))
            dense1 = tf.nn.relu(tf.matmul(reshaped, w) + b, name=scope.name)
        with tf.variable_scope('Dense2') as scope:
            # Output layer is left linear (no activation).
            w = tf.get_variable('W', [512, n_actions], initializer=xavier_initializer())
            b = tf.get_variable('b', [n_actions], initializer=tf.constant_initializer(0.0))
            dense2 = tf.add(tf.matmul(dense1, w), b, name=scope.name)
        return dense2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.