class AC_Network():
    def __init__(self,s_size,a_size,scope,trainer):
        ....
        ....
        ....
        if scope != 'global':
            self.actions = tf.placeholder(shape=[None],dtype=tf.int32)
            self.actions_onehot = tf.one_hot(self.actions,a_size,dtype=tf.float32)
            self.target_v = tf.placeholder(shape=[None],dtype=tf.float32)
            self.advantages = tf.placeholder(shape=[None],dtype=tf.float32)
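In the full class, the elided lines build the policy and value heads, and these placeholders then feed an actor-critic loss. As a hedged sketch only (self.policy, self.value, and the 0.5/0.01 weighting are assumptions, not necessarily the author's exact code):

            # Sketch: self.policy (softmax over actions) and self.value (scalar
            # state value) are assumed to be defined in the elided code above.
            self.responsible_outputs = tf.reduce_sum(self.policy * self.actions_onehot, [1])
            self.value_loss = 0.5 * tf.reduce_sum(tf.square(self.target_v - tf.reshape(self.value, [-1])))
            self.entropy = -tf.reduce_sum(self.policy * tf.log(self.policy))
            self.policy_loss = -tf.reduce_sum(tf.log(self.responsible_outputs) * self.advantages)
            self.loss = 0.5 * self.value_loss + self.policy_loss - 0.01 * self.entropy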
class Worker():
    ....
    ....
    ....
    def work(self,max_episode_length,gamma,global_AC,sess,coord):
        episode_count = 0
        total_step_count = 0
        print("Starting worker " + str(self.number))
        with sess.as_default(), sess.graph.as_default():
            while not coord.should_stop():
# Copies one set of variables to another.
# Used to set worker network parameters to those of global network.
def update_target_graph(from_scope,to_scope):
    from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)
    to_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, to_scope)
    op_holder = []
    for from_var,to_var in zip(from_vars,to_vars):
        op_holder.append(to_var.assign(from_var))
    return op_holder
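As a usage sketch (the scope names and session setup below are illustrative, not taken from the source): each worker builds its copy ops once against its own variable scope, then runs them to mirror the global network before gathering new experience.

# Illustrative only: assumes networks were already built under the scopes
# 'global' and 'worker_0'.
update_local_ops = update_target_graph('global', 'worker_0')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_local_ops)  # worker_0's parameters now match the global network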
with tf.device("/cpu:0"):
    master_network = AC_Network(s_size,a_size,'global',None) # Generate global network
    num_workers = multiprocessing.cpu_count() # Set number of workers to the number of available CPU threads
    workers = []
    # Create worker classes
    for i in range(num_workers):
        workers.append(Worker(DoomGame(),i,s_size,a_size,trainer,saver,model_path))
with tf.Session() as sess:
    coord = tf.train.Coordinator()
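The embed stops at the coordinator. The rest of this setup typically initializes variables, launches each worker's work() loop in its own thread, and waits on the coordinator. A sketch under that assumption (it relies on import threading and on max_episode_length and gamma being defined earlier; it is not the author's verbatim code):

    sess.run(tf.global_variables_initializer())
    worker_threads = []
    for worker in workers:
        # Bind the current worker to the thread target (the default argument
        # avoids the late-binding pitfall of closures created in a loop).
        worker_work = lambda w=worker: w.work(max_episode_length, gamma, master_network, sess, coord)
        t = threading.Thread(target=worker_work)
        t.start()
        worker_threads.append(t)
    coord.join(worker_threads)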
class AC_Network():
    def __init__(self,s_size,a_size,scope,trainer):
        with tf.variable_scope(scope):
            # Input and visual encoding layers
            self.inputs = tf.placeholder(shape=[None,s_size],dtype=tf.float32)
            self.imageIn = tf.reshape(self.inputs,shape=[-1,84,84,1])
            self.conv1 = slim.conv2d(activation_fn=tf.nn.elu,
                inputs=self.imageIn,num_outputs=16,
                kernel_size=[8,8],stride=[4,4],padding='VALID')
            self.conv2 = slim.conv2d(activation_fn=tf.nn.elu,
                inputs=self.conv1,num_outputs=32,
                kernel_size=[4,4],stride=[2,2],padding='VALID')
            # (The embed is cut off here; the second layer's parameters above
            # follow the usual second conv layer in this architecture.)
import tensorflow as tf
import numpy as np
import tensorflow.contrib.slim as slim

total_layers = 25 # Specify how deep we want our network
units_between_stride = total_layers // 5  # integer division so this can be used as a loop count

def denseBlock(input_layer,i,j):
    with tf.variable_scope("dense_unit"+str(i)):
        nodes = []
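The embed is cut off after the nodes list is created. A sketch of a complete dense block in the same style (the 64-filter width and use of slim.batch_norm are assumptions): each 3x3 convolution in the block takes the channel-wise concatenation of all previous outputs as its input, which is the defining DenseNet connectivity pattern.

def denseBlock(input_layer, i, j):
    # Sketch of a full dense block; widths and normalization are assumptions.
    with tf.variable_scope("dense_unit" + str(i)):
        nodes = []
        a = slim.conv2d(input_layer, 64, [3, 3], normalizer_fn=slim.batch_norm)
        nodes.append(a)
        for z in range(j):
            # Each new unit sees the concatenation of every earlier unit's output.
            b = slim.conv2d(tf.concat(nodes, 3), 64, [3, 3], normalizer_fn=slim.batch_norm)
            nodes.append(b)
        return b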
import tensorflow as tf
import numpy as np
import tensorflow.contrib.slim as slim

total_layers = 25 # Specify how deep we want our network
units_between_stride = total_layers // 5  # integer division so this can be used as a loop count

def highwayUnit(input_layer,i):
    with tf.variable_scope("highway_unit"+str(i)):
        H = slim.conv2d(input_layer,64,[3,3])
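This embed is also truncated. A sketch of the rest of a highway unit (the gate's bias initialization and the layer width are assumptions): alongside the transform H, a sigmoid gate T decides, per unit, how much of the transformed signal versus the raw input passes through.

def highwayUnit(input_layer, i):
    # Sketch of a full highway unit; gate bias and width are assumptions.
    with tf.variable_scope("highway_unit" + str(i)):
        H = slim.conv2d(input_layer, 64, [3, 3])
        # Transform gate: starts biased toward carrying the input through unchanged.
        T = slim.conv2d(input_layer, 64, [3, 3],
                        biases_initializer=tf.constant_initializer(-1.0),
                        activation_fn=tf.nn.sigmoid)
        output = H * T + input_layer * (1.0 - T)
        return output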
import tensorflow as tf
import numpy as np
import tensorflow.contrib.slim as slim

total_layers = 25 # Specify how deep we want our network
units_between_stride = total_layers // 5  # integer division so this can be used as a loop count

def resUnit(input_layer,i):
    with tf.variable_scope("res_unit"+str(i)):
        part1 = slim.batch_norm(input_layer,activation_fn=None)
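The embed stops after the first batch norm. A sketch of a complete pre-activation residual unit in the same style (the 64-filter width is an assumption): two batch norm, ReLU, conv stages, with the block's input added back through the skip connection.

def resUnit(input_layer, i):
    # Sketch of a full pre-activation residual unit; widths are assumptions.
    with tf.variable_scope("res_unit" + str(i)):
        part1 = slim.batch_norm(input_layer, activation_fn=None)
        part2 = tf.nn.relu(part1)
        part3 = slim.conv2d(part2, 64, [3, 3], activation_fn=None)
        part4 = slim.batch_norm(part3, activation_fn=None)
        part5 = tf.nn.relu(part4)
        part6 = slim.conv2d(part5, 64, [3, 3], activation_fn=None)
        # Identity skip connection: the unit learns a residual on top of its input.
        output = input_layer + part6
        return output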