Skip to content

Instantly share code, notes, and snippets.

@tristansokol
Created April 28, 2018 05:34
Show Gist options
  • Save tristansokol/385b599f9425c3cc8f53fbc8fb4c7e82 to your computer and use it in GitHub Desktop.
Save tristansokol/385b599f9425c3cc8f53fbc8fb4c7e82 to your computer and use it in GitHub Desktop.
def best_sequence(self):
    """
    Get the prefix of the trajectory with the best cumulative reward.

    ``self.reward_history[i]`` holds the running reward total after the action
    stored at ``self.action_history[i]``. The returned prefix ends at the
    *first* step that reaches the maximum running total, so ties resolve to
    the shortest prefix.

    Returns:
        A new list with the actions up to and including that step, or an
        empty list when no steps have been recorded yet.
    """
    rewards = self.reward_history
    if not rewards:
        # Nothing recorded yet (e.g. straight after reset): the best prefix of
        # an empty trajectory is empty. max() would raise ValueError here.
        return []
    # list.index returns the first matching position (shortest prefix on
    # ties). It compares by identity before equality, so the very object
    # returned by max() is always found, even if it is a NaN.
    best_end = rewards.index(max(rewards))
    return self.action_history[:best_end + 1]
# pylint: disable=E0202
def reset(self, **kwargs):
    """
    Begin a new episode and clear the per-episode trajectory records.

    The action log, the running-total log and the running total are all
    emptied. ``total_steps_ever`` is deliberately left untouched.

    Args:
        **kwargs: Forwarded unchanged to ``self.env.reset``.

    Returns:
        Whatever ``self.env.reset`` returns.
    """
    self.action_history, self.reward_history = [], []
    self.total_reward = 0
    inner_reset = self.env.reset
    return inner_reset(**kwargs)
def step(self, action):
    """
    Advance the wrapped environment by one step and record the transition.

    Args:
        action: The action forwarded to ``self.env.step``. It must support
            ``.copy()``; the copy is what gets recorded, so later mutation of
            the caller's object does not rewrite the history.

    Returns:
        The ``(obs, rew, done, info)`` tuple from ``self.env.step``, unchanged.
    """
    # Counted before stepping, so an attempted step that raises still counts.
    self.total_steps_ever += 1
    # Snapshot the action before handing it to the environment.
    recorded = action.copy()
    obs, rew, done, info = self.env.step(action)
    # Record the action and the running total together, and only once the
    # inner step has succeeded, so action_history and reward_history always
    # have the same length and stay index-aligned (best_sequence slices the
    # action log using an index found in the reward log).
    self.total_reward += rew
    self.action_history.append(recorded)
    self.reward_history.append(self.total_reward)
    return obs, rew, done, info
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment