A template to start a project using OpenAI gym with PyTorch
"""A template to implement RL agent with OpenAI Gym | |
Usage: python ./gym_template.py --env=CarRacing-v0 --algo=policy_gradient --epochs 1 | |
implementation of algorithms need to be ./algorithms/ directory, or change the following line to your env | |
> algo = import_module('algorithms.'+args.algo) | |
""" | |
import argparse
import numpy as np
import gym
from importlib import import_module

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def main():
    parser = argparse.ArgumentParser(description="Main for training RL agent")
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                        help='discount factor (default: 0.99)')
    parser.add_argument('--epochs', type=int, default=1000, metavar='N',
                        help='number of epochs to train (default: 1,000)')
    parser.add_argument('--env', type=str, default=None,
                        help="https://github.com/openai/gym/wiki/Table-of-environments")
    parser.add_argument('--algo', type=str, default='dqn', help="learning algorithm")
    args = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    env_str = args.env
    env = gym.make(env_str)
    # Continuous (Box) action spaces report their dimensionality via .shape;
    # discrete action spaces report the number of actions via .n.
    if isinstance(env.action_space, gym.spaces.box.Box):
        num_actions = env.action_space.shape[0]
    else:
        num_actions = env.action_space.n
    num_observations = env.observation_space
    print("Created {} env with {} actions and observation space {}".format(env_str, num_actions, num_observations))

    algo = import_module('algorithms.' + args.algo)
    agent = Agent(algo, env)
    agent.train(args.epochs)
class Agent():
    def __init__(self, algo, env):
        self.algo = algo
        self.env = env

    def train(self, num_epochs):
        for e in range(num_epochs):
            # Initialization
            state = self.env.reset()
            done = False
            total_reward = 0
            self.env.render()

            while not done:
                # Placeholder: sample a random action. Replace this with an
                # action chosen by self.algo when implementing a learner.
                action = self.env.action_space.sample()
                next_state, reward, done, _ = self.env.step(action)
                self.env.render()
                total_reward += reward
                state = next_state

            print("Done with reward ", total_reward)

        self.env.close()

if __name__ == "__main__":
    main()
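The template imports a learning-algorithm module with import_module('algorithms.' + args.algo) but never shows what such a module contains. As a minimal sketch of one way to fill that in (the file name dqn.py and the Policy class with its act() method are illustrative assumptions, not part of the original template), a file at ./algorithms/dqn.py might expose a small policy network:

import torch
import torch.nn as nn

class Policy(nn.Module):
    """Hypothetical policy network; the original gist does not fix this interface."""
    def __init__(self, num_observations, num_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_observations, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions),
        )

    def forward(self, state):
        return self.net(state)

    def act(self, state):
        # Pick the highest-scoring action; no exploration in this sketch.
        with torch.no_grad():
            return int(self.forward(state).argmax())

Agent.train could then construct self.algo.Policy(...) once and call act(state) in place of self.env.action_space.sample().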