class Worker():
    # ... (other methods omitted) ...

    def work(self, max_episode_length, gamma, global_AC, sess, coord):
        episode_count = 0
        total_steps = 0
        print("Starting worker " + str(self.number))
        with sess.as_default(), sess.graph.as_default():
            while not coord.should_stop():
                # Sync this worker's local network with the latest global parameters.
                sess.run(self.update_local_ops)
                episode_buffer = []
                episode_values = []
                episode_frames = []
                episode_reward = 0
                episode_step_count = 0
                d = False
                self.env.new_episode()
                s = self.env.get_state().screen_buffer
                episode_frames.append(s)
                s = process_frame(s)
                rnn_state = self.local_AC.state_init
                while not self.env.is_episode_finished():
                    # Take an action using probabilities from the policy network output.
                    a_dist, v, rnn_state = sess.run(
                        [self.local_AC.policy, self.local_AC.value, self.local_AC.state_out],
                        feed_dict={self.local_AC.inputs: [s],
                                   self.local_AC.state_in[0]: rnn_state[0],
                                   self.local_AC.state_in[1]: rnn_state[1]})
                    # Sample an action index from the policy distribution.
                    a = np.random.choice(a_dist[0], p=a_dist[0])
                    a = np.argmax(a_dist == a)
                    r = self.env.make_action(self.actions[a]) / 100.0
                    d = self.env.is_episode_finished()
                    if not d:
                        s1 = self.env.get_state().screen_buffer
                        episode_frames.append(s1)
                        s1 = process_frame(s1)
                    else:
                        s1 = s
                    episode_buffer.append([s, a, r, s1, d, v[0, 0]])
                    episode_values.append(v[0, 0])
                    episode_reward += r
                    s = s1
                    total_steps += 1
                    episode_step_count += 1
                    # Specific to VizDoom: pause the game for a short time between steps.
                    if self.sleep_time > 0:
                        sleep(self.sleep_time)
                    # If the episode hasn't ended but the experience buffer is full,
                    # make an update step using that experience rollout.
                    if len(episode_buffer) == 30 and not d and episode_step_count != max_episode_length - 1:
                        # Since we don't know what the true final return is, we "bootstrap"
                        # from our current value estimation.
                        v1 = sess.run(self.local_AC.value,
                                      feed_dict={self.local_AC.inputs: [s],
                                                 self.local_AC.state_in[0]: rnn_state[0],
                                                 self.local_AC.state_in[1]: rnn_state[1]})[0, 0]
                        v_l, p_l, e_l, g_n, v_n = self.train(global_AC, episode_buffer, sess, gamma, v1)
                        episode_buffer = []
                        sess.run(self.update_local_ops)
                    if d:
                        break
                self.episode_rewards.append(episode_reward)
                self.episode_lengths.append(episode_step_count)
                self.episode_mean_values.append(np.mean(episode_values))
                # Update the network using the experience buffer at the end of the episode.
                v_l, p_l, e_l, g_n, v_n = self.train(global_AC, episode_buffer, sess, gamma, 0.0)
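
The self.train method is not shown in this excerpt. For intuition about how the bootstrap value v1 enters the update, here is a minimal, hypothetical sketch of computing discounted returns for a truncated rollout, with the running sum seeded by the bootstrap estimate (the function name and sample numbers are illustrative, not part of the gist):

import numpy as np

def discounted_returns(rewards, gamma, bootstrap_value):
    # Seed the running sum with the bootstrap value so the truncated
    # rollout still has a well-defined return target at every step.
    running = bootstrap_value
    returns = np.zeros(len(rewards))
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Example: a 3-step rollout cut short mid-episode, bootstrapped from v1 = 0.5
print(discounted_returns([0.01, 0.0, 0.02], gamma=0.99, bootstrap_value=0.5))

At the end of an episode the same computation applies with a bootstrap value of 0.0, matching the final self.train call above.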
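
For context, work() expects a shared tf.Session and a tf.train.Coordinator; the gist does not show the launch code, but a typical way to drive several workers (names such as workers, master_network, max_episode_length and gamma are assumed to be defined elsewhere) looks roughly like this:

import threading
import tensorflow as tf

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    sess.run(tf.global_variables_initializer())
    threads = []
    for worker in workers:
        t = threading.Thread(
            target=lambda w=worker: w.work(max_episode_length, gamma,
                                           master_network, sess, coord))
        t.start()
        threads.append(t)
    coord.join(threads)  # blocks until coord.request_stop() is called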