#!/usr/bin/env python
"""
A scripted agent called "Just Enough Retained Knowledge".
"""

import random

import gym
import numpy as np

import gym_remote.client as grc
import gym_remote.exceptions as gre

EMA_RATE = 0.2
EXPLOIT_BIAS = 0.25
TOTAL_TIMESTEPS = int(1e6)
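# EXPLOIT_BIAS is the base probability of replaying the best recorded
# solution; the effective probability grows linearly toward 1 as
# env.total_steps_ever approaches TOTAL_TIMESTEPS. EMA_RATE appears to be
# unused in this script.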


def main():
    """Run JERK on the attached environment."""
    env = grc.RemoteEnv('tmp/sock')
    env = TrackedEnv(env)
    new_ep = True
    solutions = []
    while True:
        if new_ep:
            # With probability EXPLOIT_BIAS plus the fraction of the timestep
            # budget already spent, replay the best solution found so far.
            if (solutions and
                    random.random() < EXPLOIT_BIAS + env.total_steps_ever / TOTAL_TIMESTEPS):
                solutions = sorted(solutions, key=lambda x: np.mean(x[0]))
                best_pair = solutions[-1]
                new_rew = exploit(env, best_pair[1])
                best_pair[0].append(new_rew)
                print('replayed best with reward %f' % new_rew)
                continue
            else:
                env.reset()
                new_ep = False
        rew, new_ep = move(env, 100)
        if not new_ep and rew <= 0:
            print('backtracking due to non-positive reward: %f' % rew)
            _, new_ep = move(env, 70, left=True)
        if new_ep:
            # Record (rewards seen for this sequence, best action prefix).
            solutions.append(([max(env.reward_history)], env.best_sequence()))


def move(env, num_steps, left=False, jump_prob=1.0 / 10.0, jump_repeat=4):
    """
    Move right or left for a certain number of steps,
    jumping periodically.
    """
    total_rew = 0.0
    done = False
    steps_taken = 0
    jumping_steps_left = 0
    while not done and steps_taken < num_steps:
        # 12-button Genesis action; index 6 is LEFT, 7 is RIGHT, 0 is B (jump).
        action = np.zeros((12,), dtype=bool)
        action[6] = left
        action[7] = not left
        if jumping_steps_left > 0:
            action[0] = True
            jumping_steps_left -= 1
        else:
            if random.random() < jump_prob:
                # Hold the jump button for jump_repeat consecutive steps.
                jumping_steps_left = jump_repeat - 1
                action[0] = True
        _, rew, done, _ = env.step(action)
        total_rew += rew
        steps_taken += 1
        if done:
            break
    return total_rew, done
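
# For reference: gym-retro's Genesis action space is MultiBinary(12) over the
# buttons ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y',
# 'X', 'Z'], which is why move() sets index 6 (LEFT), index 7 (RIGHT), and
# index 0 (B, the jump button in Sonic).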


def exploit(env, sequence):
    """
    Replay an action sequence; pad with NOPs if needed.

    Returns the final cumulative reward.
    """
    env.reset()
    done = False
    idx = 0
    while not done:
        if idx >= len(sequence):
            # Ran out of recorded actions; send a no-op until the episode ends.
            _, _, done, _ = env.step(np.zeros((12,), dtype=bool))
        else:
            _, _, done, _ = env.step(sequence[idx])
        idx += 1
    return env.total_reward


class TrackedEnv(gym.Wrapper):
    """
    An environment that tracks the current trajectory and
    the total number of timesteps ever taken.
    """
    def __init__(self, env):
        super(TrackedEnv, self).__init__(env)
        self.action_history = []
        self.reward_history = []
        self.total_reward = 0
        self.total_steps_ever = 0

    def best_sequence(self):
        """
        Get the prefix of the trajectory with the best
        cumulative reward.
        """
        max_cumulative = max(self.reward_history)
        for i, rew in enumerate(self.reward_history):
            if rew == max_cumulative:
                return self.action_history[:i + 1]
        raise RuntimeError('unreachable')

    # pylint: disable=E0202
    def reset(self, **kwargs):
        self.action_history = []
        self.reward_history = []
        self.total_reward = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        self.total_steps_ever += 1
        self.action_history.append(action.copy())
        obs, rew, done, info = self.env.step(action)
        self.total_reward += rew
        # reward_history stores the cumulative reward at each step, so
        # best_sequence() can find the prefix where it peaked.
        self.reward_history.append(self.total_reward)
        return obs, rew, done, info


if __name__ == '__main__':
    try:
        main()
    except gre.GymRemoteError as exc:
        print('exception', exc)
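
# Local-testing sketch (an assumption, not part of this gist): if the
# retro_contest package from the OpenAI Retro Contest repo is installed and a
# Sonic ROM has been imported, main() can be pointed at a local environment
# by replacing the RemoteEnv line with something like:
#
#     from retro_contest.local import make
#     env = make(game='SonicTheHedgehog-Genesis', state='GreenHillZone.Act1')
#     env = TrackedEnv(env)
#
# As written, the script instead blocks on the Unix socket 'tmp/sock' created
# by the contest's remote environment server.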