Skip to content

Instantly share code, notes, and snippets.

@heiner
Created June 25, 2021 16:30
Show Gist options
  • Select an option

  • Save heiner/a54d10c47b9685c08acf807a15a85f6f to your computer and use it in GitHub Desktop.

Select an option

Save heiner/a54d10c47b9685c08acf807a15a85f6f to your computer and use it in GitHub Desktop.
class MemoryEnv:
    """Toy episodic environment that rewards remembering past observations.

    Every observation is a fresh ``torch.randn(obs_size)`` vector.  All
    observations of the current episode are kept in ``self.q``.  Once at
    least ``lookback`` observations have been recorded, each step is scored
    against element 0 of ``self.q[-lookback]`` (the most recently returned
    observation is ``self.q[-1]``):

    * action ``0`` claims that element was negative (``< 0``),
    * action ``1`` claims it was non-negative (``>= 0``).

    A correct claim yields ``+1.0``, a wrong one ``-1.0``; steps taken
    before enough history exists yield ``0.0``.

    With the defaults (``lookback=4``, ``episode_length=20``) the first
    three steps are unrewarded, so the max return per episode is
    ``20 - 4 + 1 == 17``.

    Args:
        lookback: How far back in the history the scored observation lies
            (``1`` means the observation most recently returned).
        episode_length: Number of steps after which ``done`` becomes true.
        obs_size: Length of each observation vector.

    Raises:
        ValueError: If ``lookback`` is smaller than 1.
    """

    def __init__(self, lookback=4, episode_length=20, obs_size=4):
        if lookback < 1:
            # q[-0] would silently index the *first* observation instead.
            raise ValueError("lookback must be >= 1")
        self.lookback = lookback
        self.episode_length = episode_length
        self.obs_size = obs_size
        self.q = []  # observation history of the current episode
        self.obs = None  # most recent observation; set by reset()/step()
        self.steps = 0  # steps taken in the current episode

    def reset(self):
        """Start a new episode and return its first observation."""
        # Clear in place so external references to ``self.q`` stay valid.
        del self.q[:]
        self.obs = torch.randn(self.obs_size)
        self.q.append(self.obs)
        self.steps = 0
        return self.obs

    def step(self, action):
        """Advance the episode by one step.

        Args:
            action: ``0`` (scored element was negative) or ``1`` (scored
                element was non-negative).

        Returns:
            A tuple ``(obs, reward, done)``: the new observation, the float
            reward for ``action``, and whether the episode has reached
            ``episode_length`` steps.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet, or if
                ``action`` is neither ``0`` nor ``1``.
        """
        if not self.q:
            raise RuntimeError("reset() must be called before step()")

        # Validate on every step, not only once rewards are active, so a
        # bad action is reported immediately instead of a few steps later.
        if action == 0:
            claims_negative = True
        elif action == 1:
            claims_negative = False
        else:
            raise RuntimeError("bad action")

        reward = 0.0
        if len(self.q) >= self.lookback:
            target = self.q[-self.lookback][0]
            if claims_negative:
                correct = bool(target < 0)
            else:
                correct = bool(target >= 0)
            reward = 1.0 if correct else -1.0

        self.obs = torch.randn(self.obs_size)
        self.q.append(self.obs)
        self.steps += 1
        done = self.steps >= self.episode_length
        return self.obs, reward, done
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment