@awjuliani · Created December 16, 2016 22:53
import numpy as np       # assumed: imported at the top of the full script
from time import sleep   # assumed: used below via sleep(self.sleep_time)

class Worker():
    # ... (__init__, train, and other members omitted; process_frame() is
    # defined elsewhere in the full script) ...
    def work(self,max_episode_length,gamma,global_AC,sess,coord):
        episode_count = 0
        total_steps = 0
        print("Starting worker " + str(self.number))
        with sess.as_default(), sess.graph.as_default():
            while not coord.should_stop():
                sess.run(self.update_local_ops)
                episode_buffer = []
                episode_values = []
                episode_frames = []
                episode_reward = 0
                episode_step_count = 0
                d = False

                self.env.new_episode()
                s = self.env.get_state().screen_buffer
                episode_frames.append(s)
                s = process_frame(s)
                rnn_state = self.local_AC.state_init
                while self.env.is_episode_finished() == False:
                    # Take an action using probabilities from policy network output.
                    a_dist,v,rnn_state = sess.run([self.local_AC.policy,self.local_AC.value,self.local_AC.state_out],
                        feed_dict={self.local_AC.inputs:[s],
                                   self.local_AC.state_in[0]:rnn_state[0],
                                   self.local_AC.state_in[1]:rnn_state[1]})
                    a = np.random.choice(a_dist[0],p=a_dist[0])
                    a = np.argmax(a_dist == a)
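                    # The two lines above sample an action index from the policy distribution:
                    # first draw a probability value weighted by itself, then recover its index.
                    # Assuming the probabilities are distinct, this is equivalent to sampling an
                    # index directly, e.g. np.random.choice(len(a_dist[0]), p=a_dist[0]).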
                    r = self.env.make_action(self.actions[a]) / 100.0
                    d = self.env.is_episode_finished()
                    if d == False:
                        s1 = self.env.get_state().screen_buffer
                        episode_frames.append(s1)
                        s1 = process_frame(s1)
                    else:
                        s1 = s
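                    # Store the transition (state, action, reward, next state, done flag,
                    # value estimate) for the next training update.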
                    episode_buffer.append([s,a,r,s1,d,v[0,0]])
                    episode_values.append(v[0,0])

                    episode_reward += r
                    s = s1
                    total_steps += 1
                    episode_step_count += 1

                    # Specific to VizDoom. We sleep the game for a specific time.
                    if self.sleep_time>0:
                        sleep(self.sleep_time)
                    # If the episode hasn't ended, but the experience buffer is full, then we
                    # make an update step using that experience rollout.
                    if len(episode_buffer) == 30 and d != True and episode_step_count != max_episode_length - 1:
                        # Since we don't know what the true final return is, we "bootstrap" from our current
                        # value estimation.
                        v1 = sess.run(self.local_AC.value,
                            feed_dict={self.local_AC.inputs:[s],
                                       self.local_AC.state_in[0]:rnn_state[0],
                                       self.local_AC.state_in[1]:rnn_state[1]})[0,0]
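                        # v1 stands in for the return from the final state of the truncated
                        # rollout; train() presumably discounts back from it (see sketch below).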
                        v_l,p_l,e_l,g_n,v_n = self.train(global_AC,episode_buffer,sess,gamma,v1)
                        episode_buffer = []
                        sess.run(self.update_local_ops)
                    if d == True:
                        break

                self.episode_rewards.append(episode_reward)
                self.episode_lengths.append(episode_step_count)
                self.episode_mean_values.append(np.mean(episode_values))

                # Update the network using the experience buffer at the end of the episode.
                v_l,p_l,e_l,g_n,v_n = self.train(global_AC,episode_buffer,sess,gamma,0.0)
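
The train() method called above is not part of this gist. As a rough sketch (not necessarily the author's exact implementation): the bootstrap value passed in, v1 for a truncated rollout or 0.0 at episode end, is appended to the collected rewards so that discounted returns and advantages can be computed for every step, and those arrays are then fed to the local network's loss and gradient ops. The helper below illustrates only the discounting step; the name discount, the use of scipy.signal.lfilter, and the commented usage lines are assumptions for illustration.

import scipy.signal

def discount(x, gamma):
    # Discounted cumulative sum: out[i] = x[i] + gamma*x[i+1] + gamma^2*x[i+2] + ...
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

# Hypothetical use inside train(), with rewards/values taken from episode_buffer:
#   rewards_plus = np.asarray(rewards + [bootstrap_value])
#   discounted_rewards = discount(rewards_plus, gamma)[:-1]
#   values_plus = np.asarray(values + [bootstrap_value])
#   advantages = np.asarray(rewards) + gamma * values_plus[1:] - values_plus[:-1]
#   advantages = discount(advantages, gamma)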