# PyTorch REINFORCE example on CartPole-v0
# Gist by @zou3519, created October 31, 2017
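#
# REINFORCE (vanilla policy gradient): roll out episodes with the current
# stochastic policy, then step the parameters along the gradient estimate
#   grad J(theta) ~ sum_t grad log pi(a_t | s_t) * R_t,
# where R_t is the (normalized) discounted return from step t. Log-probabilities
# are recorded during the rollout and the update is applied once per episode in
# finish_episode().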
import argparse
import gym
import numpy as np
from itertools import count
import torch
import torch.distributions as D
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()
env = gym.make('CartPole-v0')
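# Note: env.seed() and the 4-tuple return of env.step() are the pre-0.26 gym API;
# newer gym/gymnasium seeds via env.reset(seed=...) and returns a 5-tuple.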
env.seed(args.seed)
torch.manual_seed(args.seed)
class Policy(nn.Module):
    """Two-layer MLP mapping a CartPole observation (4 values) to action probabilities."""

    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.affine2 = nn.Linear(128, 2)
        # Per-episode buffers filled during the rollout and consumed by finish_episode().
        self.saved_logprobs = []
        self.rewards = []

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.affine2(x)
        # States are fed one at a time (no batch dimension), so the action
        # dimension of the 1-D score vector is dim 0.
        return F.softmax(action_scores, dim=0)
policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
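# Illustrative shape check: a CartPole observation is a length-4 float vector,
# so policy(Variable(torch.ones(4))) returns a length-2 vector of action
# probabilities summing to 1.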
def select_action(state):
    state = torch.from_numpy(state).float()
    probs = policy(Variable(state))
    # D.Multinomial here is the early torch.distributions name for a single
    # categorical draw over the action probabilities (the role D.Categorical
    # plays in later releases).
    distribution = D.Multinomial(probs)
    action = distribution.sample()
    # Save log pi(a_t | s_t) so finish_episode() can form the policy-gradient loss.
    logprob = distribution.log_prob(action)
    policy.saved_logprobs.append(logprob.unsqueeze(0))
    return action.data
def finish_episode():
    # Compute the discounted return R_t for every step of the finished episode,
    # walking backwards through the rewards: R_t = r_t + gamma * R_{t+1}.
    R = 0
    rewards = []
    for r in policy.rewards[::-1]:
        R = r + args.gamma * R
        rewards.insert(0, R)
    rewards = torch.Tensor(rewards)
    # Normalize the returns (zero mean, unit variance) to reduce the variance of
    # the gradient estimate; eps guards against division by zero.
    rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
    # REINFORCE loss: -sum_t log pi(a_t | s_t) * R_t.
    logprobs = torch.cat(policy.saved_logprobs, 0)
    loss = (-logprobs * Variable(rewards)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Clear the per-episode buffers for the next rollout.
    del policy.rewards[:]
    del policy.saved_logprobs[:]
running_reward = 10
for i_episode in count(1):
    state = env.reset()
    for _t in range(10000):  # Don't infinite loop while learning
        action = select_action(state)
        state, reward, done, _ = env.step(action[0])
        if args.render:
            env.render()
        policy.rewards.append(reward)
        if done:
            break
    # Exponential moving average of episode length (CartPole gives +1 reward per step).
    running_reward = running_reward * 0.99 + _t * 0.01
    finish_episode()
    if i_episode % args.log_interval == 0:
        print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
            i_episode, _t, running_reward))
    if running_reward > env.spec.reward_threshold:
        print("Solved! Running reward is now {} and "
              "the last episode runs to {} time steps!".format(running_reward, _t))
        break
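# Example invocation (assuming the gist is saved locally as reinforce.py):
#   python reinforce.py --gamma 0.99 --seed 543 --log_interval 10
# Add --render to visualize the rollouts.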