Skip to content

Instantly share code, notes, and snippets.

@rayheberer
Created August 17, 2018 00:28
Show Gist options
  • Save rayheberer/dcbdd8f0f3562b7a86b167c98648ee36 to your computer and use it in GitHub Desktop.
Save rayheberer/dcbdd8f0f3562b7a86b167c98648ee36 to your computer and use it in GitHub Desktop.
class A2CAtari(base_agent.BaseAgent):
    # ...
    # ...

    @staticmethod
    def _mask_function_policy(function_id_policy, action_mask):
        """Restrict a function-id policy to available actions and renormalize.

        Args:
            function_id_policy: policy output for a single observation. A
                leading batch axis of size 1 is tolerated; the array is
                flattened to one probability per function id.
            action_mask: 0/1 vector with one entry per function id, 1 where
                the function may be chosen.

        Returns:
            A 1-D float64 array that sums to 1 and is zero wherever
            ``action_mask`` is zero (NaN entries in the input propagate).

        Raises:
            ValueError: if ``action_mask`` marks no function as available.
        """
        # Flatten a possible (1, N) batch output to (N,) so the vector length
        # equals the number of function ids; use float64 so the renormalized
        # vector comfortably passes np.random.choice's sum-to-one check.
        policy = np.ravel(np.asarray(function_id_policy, dtype=np.float64))
        policy = policy * action_mask

        total = np.sum(policy)
        if total == 0:
            # All available functions have (or underflowed to) zero
            # probability; dividing would produce NaN, so sample uniformly
            # among the available functions instead.
            available = np.sum(action_mask)
            if available == 0:
                raise ValueError("no available actions to sample from")
            return action_mask / available
        return policy / total

    def _sample_action(self,
                       screen_features,
                       minimap_features,
                       flat_features,
                       available_actions):
        """Sample actions and arguments from policy output layers.

        Each observation is fed to ``self.network`` as a batch of one. The
        resulting function-identifier policy is restricted to
        ``available_actions`` and renormalized, and a function identifier is
        drawn from it. Argument sampling is elided from this excerpt.

        Args:
            screen_features: screen observation for a single timestep.
            minimap_features: minimap observation for a single timestep.
            flat_features: non-spatial observation for a single timestep.
            available_actions: indices into ``FUNCTIONS`` that may be chosen
                this step (presumably the env's available-action ids —
                confirm against caller).
        """
        # The network placeholders take a batch; add a leading axis of size 1.
        screen_features = np.expand_dims(screen_features, 0)
        minimap_features = np.expand_dims(minimap_features, 0)
        flat_features = np.expand_dims(flat_features, 0)

        # 1 for every function id that is legal this step, 0 otherwise.
        action_mask = np.zeros(len(FUNCTIONS), dtype=np.int32)
        action_mask[available_actions] = 1

        feed_dict = {self.network.screen_features: screen_features,
                     self.network.minimap_features: minimap_features,
                     self.network.flat_features: flat_features}

        function_id_policy = self.sess.run(
            self.network.function_policy,
            feed_dict=feed_dict)

        # renormalize distribution over the available function identifiers
        function_id_policy = self._mask_function_policy(
            function_id_policy, action_mask)

        # Derive ids from the flattened policy's size: len() of a batched
        # (1, N) output would be 1, not N, and np.random.choice would reject
        # the size mismatch between ids and probabilities.
        function_ids = np.arange(function_id_policy.size)

        # sample function identifier
        action_id = np.random.choice(function_ids, p=function_id_policy)
        # ...
        # ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment