Last active: January 1, 2020 01:51
-
-
Save rayheberer/fa1842825cdfb2952c1b3cd3192be343 to your computer and use it in GitHub Desktop.
Train step implemented
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
def dense(x, weights, bias, activation=tf.identity, **activation_kwargs):
    """Fully connected layer: return ``activation(x @ weights + bias)``.

    Any extra keyword arguments are passed straight through to
    ``activation``; the default activation is the identity (a linear layer).
    """
    linear = tf.matmul(x, weights) + bias
    activated = activation(linear, **activation_kwargs)
    return activated
def init_weights(shape, initializer):
    """Return a trainable float32 ``tf.Variable`` holding ``initializer(shape)``.

    Args:
        shape: Shape passed to ``initializer``.
        initializer: Callable that maps a shape to initial values.
    """
    initial_value = initializer(shape)
    return tf.Variable(initial_value, dtype=tf.float32, trainable=True)
class Network(object):
    """Q-function approximator built directly from TensorFlow variables.

    Maps a batch of state vectors to one output per action through ReLU
    hidden layers and a linear output layer.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 hidden_size=(50, 50),
                 weights_initializer=tf.initializers.glorot_uniform(),
                 bias_initializer=tf.initializers.zeros(),
                 optimizer=tf.optimizers.Adam,
                 **optimizer_kwargs):
        """Initialize weights and hyperparameters.

        Args:
            input_size: Length of each input (state) vector.
            output_size: Number of outputs (one Q value per action).
            hidden_size: Sequence of hidden-layer widths, one entry per
                hidden layer. Defaults to two layers of 50 units. A tuple
                default avoids a shared mutable default argument; the value
                is stored as a fresh list.
            weights_initializer: Callable mapping a shape to initial weights.
            bias_initializer: Callable mapping a shape to initial biases.
            optimizer: Optimizer class, instantiated with ``optimizer_kwargs``.
            **optimizer_kwargs: Keyword arguments forwarded to ``optimizer``.
        """
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = list(hidden_size)
        # NOTE(review): this seeds only NumPy's global RNG. The default
        # initializers are TensorFlow ones, so this presumably does not make
        # the initial weights reproducible -- confirm. Kept for compatibility.
        np.random.seed(41)
        self.initialize_weights(weights_initializer, bias_initializer)
        self.optimizer = optimizer(**optimizer_kwargs)

    def initialize_weights(self, weights_initializer, bias_initializer):
        """Create and store one weight matrix and one bias row per layer.

        Layer widths run input_size -> *hidden_size -> output_size. Sets
        ``self.weights``, ``self.biases`` and ``self.trainable_variables``
        (the weights followed by the biases).
        """
        sizes = [self.input_size] + list(self.hidden_size) + [self.output_size]
        wshapes = [[n_in, n_out] for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        # Biases are (1, n_out) rows so they broadcast across the batch axis.
        bshapes = [[1, n_out] for n_out in sizes[1:]]
        self.weights = [init_weights(s, weights_initializer) for s in wshapes]
        self.biases = [init_weights(s, bias_initializer) for s in bshapes]
        self.trainable_variables = self.weights + self.biases

    def model(self, inputs):
        """Given a batch of state vectors, return the Q values of actions.

        Hidden layers use ReLU; the final layer is linear.
        """
        # Cast so NumPy float64 batches don't clash with the float32 weights
        # inside tf.matmul.
        h = tf.cast(inputs, self.weights[0].dtype)
        for w, b in zip(self.weights[:-1], self.biases[:-1]):
            h = dense(h, w, b, tf.nn.relu)
        return dense(h, self.weights[-1], self.biases[-1])

    def train_step(self, inputs, targets, actions_one_hot):
        """Take one optimizer step on the MSE between targets and taken-action Q values.

        Args:
            inputs: Batch of state vectors.
            targets: One target Q value per sample in the batch.
            actions_one_hot: One-hot encoding of the action taken in each
                sample, shape (batch, output_size).

        Returns:
            The scalar loss computed before the update.
        """
        with tf.GradientTape() as tape:
            # Reshape rather than squeeze: squeeze also drops the batch axis
            # when the batch has a single sample, which breaks axis=1 below.
            qvalues = tf.reshape(self.model(inputs), [-1, self.output_size])
            mask = tf.cast(actions_one_hot, qvalues.dtype)
            preds = tf.reduce_sum(qvalues * mask, axis=1)
            # Flatten targets so a (batch, 1) array doesn't broadcast
            # against the (batch,) predictions.
            flat_targets = tf.reshape(tf.cast(targets, preds.dtype), [-1])
            loss = tf.losses.mean_squared_error(flat_targets, preds)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss
class Agent(object):
    """Deep Q-learning agent with an online and a target Q-network."""

    def __init__(self,
                 state_space_size,
                 action_space_size,
                 discount=0.99):
        """Set parameters, initialize network.

        Args:
            state_space_size: Length of the state vector fed to the networks.
            action_space_size: Number of discrete actions (network outputs).
            discount: Factor applied to the target network's next-state value.
        """
        self.action_space_size = action_space_size
        self.online_network = Network(state_space_size, action_space_size)
        self.target_network = Network(state_space_size, action_space_size)
        # Start the target network as a copy of the online network.
        self.update_target_network()
        self.discount = discount
        self.experience_replay = None  # TODO: not implemented yet

    def step(self, observation, training=True):
        """Observe state and rewards, select action.

        Not implemented yet: currently does nothing and returns None.
        """
        pass

    def update_target_network(self):
        """Copy the online network's current weights into the target network."""
        online = self.online_network.trainable_variables
        target = self.target_network.trainable_variables
        # Assign in place: the target's model() reads its own weights/biases
        # Variables, so rebinding trainable_variables to new Variables would
        # leave its outputs unchanged (and leak a fresh set of Variables).
        for target_var, online_var in zip(target, online):
            target_var.assign(online_var)

    def train_network(self, inputs=None, actions=None, rewards=None,
                      next_inputs=None):
        """Update online network weights from one batch of transitions.

        Args:
            inputs: Batch of states.
            actions: Integer index of the action taken in each state.
            rewards: Reward received for each transition.
            next_inputs: Batch of successor states.

        Raises:
            NotImplementedError: If any part of the batch is missing; sampling
                a batch from experience replay is not implemented yet.
        """
        if inputs is None or actions is None or rewards is None or next_inputs is None:
            raise NotImplementedError(
                "experience replay sampling is not implemented; pass inputs, "
                "actions, rewards and next_inputs explicitly")
        inputs = np.asarray(inputs, dtype=np.float32)
        next_inputs = np.asarray(next_inputs, dtype=np.float32)
        actions_one_hot = np.eye(
            self.action_space_size, dtype=np.float32)[np.asarray(actions)]
        # Reshape rather than squeeze so a single-sample batch keeps its batch axis.
        next_qvalues = np.asarray(
            self.target_network.model(next_inputs)).reshape(-1, self.action_space_size)
        # NOTE(review): every transition bootstraps from its next state;
        # terminal transitions are not masked out -- TODO add done flags.
        targets = (np.asarray(rewards, dtype=np.float32)
                   + self.discount * next_qvalues.max(axis=-1))
        self.online_network.train_step(inputs, targets, actions_one_hot)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.