Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save stevenRush/14c4343cb720adda48de25b3c7114da0 to your computer and use it in GitHub Desktop.
Save stevenRush/14c4343cb720adda48de25b3c7114da0 to your computer and use it in GitHub Desktop.
class DQNAgent:
    """Convolutional network mapping a batch of states to per-action Q-values.

    The layer objects are created once in ``__init__`` and applied by
    ``get_symbolic_qvalues``, so every call to that method goes through the
    same layer instances.

    Attributes:
        state: placeholder of shape ``[None, 80, 80, 4]`` feeding the network.
        qvalues: tensor produced by applying the network to ``state``.
    """

    def __init__(self, name):
        """Create the layers and the default ``state`` -> ``qvalues`` graph.

        Args:
            name: name of the variable scope the agent is built under.
        """
        with tf.variable_scope(name):
            self.conv1 = Conv2D(32, 8, strides=(4, 4), padding="same", activation="relu")
            self.pool1 = MaxPooling2D((2, 2))
            self.conv2 = Conv2D(64, 4, strides=(2, 2), padding="same", activation="relu")
            self.pool2 = MaxPooling2D((2, 2))
            self.conv3 = Conv2D(64, 3, strides=(1, 1), padding="same", activation="relu")
            self.pool3 = MaxPooling2D((2, 2))
            self.flatten = Flatten()
            self.fc1 = Dense(256, activation="relu")
            # The output layer lives under its own private name: ``self.qvalues``
            # is assigned the output *tensor* below, and sharing one attribute
            # for both would leave get_symbolic_qvalues calling a tensor on any
            # later invocation.
            self._qvalues_layer = Dense(ACTIONS)

            # Built inside the scope so the first application of the layers
            # happens under ``name``.
            self.state = tf.placeholder("float", [None, 80, 80, 4])
            self.qvalues = self.get_symbolic_qvalues(self.state)

    def get_symbolic_qvalues(self, state_t):
        """Apply the network to ``state_t`` and return the output-layer result.

        Safe to call repeatedly (for example on a different input tensor);
        each call reuses the layers created in ``__init__``.

        Args:
            state_t: input passed to the first convolution; the default graph
                uses a ``[None, 80, 80, 4]`` float placeholder.

        Returns:
            The result of the final ``Dense(ACTIONS)`` layer.
        """
        conv1 = self.conv1(state_t)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        flatten = self.flatten(pool3)
        fc1 = self.fc1(flatten)
        return self._qvalues_layer(fc1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment