Skip to content

Instantly share code, notes, and snippets.

@elumixor
Last active May 24, 2020 22:27
Show Gist options
  • Save elumixor/143a80a14fc0f4cd61099d4395cc2303 to your computer and use it in GitHub Desktop.
Save elumixor/143a80a14fc0f4cd61099d4395cc2303 to your computer and use it in GitHub Desktop.
Medium TRPO Files
from collections import namedtuple
import gym
import torch
# Single CartPole environment instance shared by the training code below.
env = gym.make('CartPole-v0')

# Length of an observation vector (first entry of the observation-space shape)
# and the number of discrete actions the agent can choose from.
obs_size = env.observation_space.shape[0]
num_actions = env.action_space.n

# One collected episode: per-step states, actions, rewards and next states.
Rollout = namedtuple('Rollout', 'states actions rewards next_states')
def train(epochs=100, num_rollouts=10):
    """Run the collect-then-update training loop on the module-level ``env``.

    Each epoch plays ``num_rollouts`` complete episodes with the current
    policy (``get_action``) and passes the resulting list of ``Rollout``
    objects to ``update_agent`` in a single call.

    Args:
        epochs: Number of collect-then-update iterations.
        num_rollouts: Number of episodes collected per epoch.
    """
    for _ in range(epochs):
        rollouts = [_collect_rollout() for _ in range(num_rollouts)]
        update_agent(rollouts)


def _collect_rollout():
    """Play one episode of ``env`` until ``done`` and return it as a ``Rollout``."""
    state = env.reset()
    done = False
    samples = []

    while not done:
        # Action selection only; gradients are not tracked while sampling.
        with torch.no_grad():
            action = get_action(state)

        next_state, reward, done, _ = env.step(action)
        samples.append((state, action, reward, next_state))
        state = next_state

    # The loop body always runs at least once (done starts False), so
    # ``samples`` is never empty here.
    return _samples_to_rollout(samples)


def _samples_to_rollout(samples):
    """Transpose ``(state, action, reward, next_state)`` tuples into a ``Rollout``.

    States and next states are stacked along a new leading (time) dimension
    and cast to float32. Actions and rewards become column tensors with one
    row per step.
    """
    states, actions, rewards, next_states = zip(*samples)

    # torch.from_numpy requires each observation to be a NumPy array.
    # Presumably these are 1-D CartPole observations, giving (steps, obs_size);
    # confirm against the environment version in use.
    stacked_states = torch.stack([torch.from_numpy(obs) for obs in states], dim=0).float()
    stacked_next_states = torch.stack([torch.from_numpy(obs) for obs in next_states], dim=0).float()

    return Rollout(
        states=stacked_states,
        actions=torch.as_tensor(actions).unsqueeze(1),
        rewards=torch.as_tensor(rewards).unsqueeze(1),
        next_states=stacked_next_states,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment