Created
August 4, 2020 23:28
-
-
Save ericl/a049ab150596d183ca28dac1a2f60f9c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse

import gym
from gym.spaces import Discrete, Box
import numpy as np
import ray
from ray import tune
from ray.tune import grid_search
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_learning_achieved

# Framework handles are resolved through RLlib's import helpers rather than
# plain `import` statements.
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

# Command-line interface of the example script.
parser = argparse.ArgumentParser()
# Name of the trainable handed to `tune.run` (e.g. "PPO").
parser.add_argument("--run", type=str, default="PPO")
# Boolean switch; False unless `--torch` is given on the command line.
parser.add_argument("--torch", action="store_true")
# When set, the script verifies that the stop reward was actually reached.
parser.add_argument("--as-test", action="store_true")
# Stopping criteria: whichever of the three limits is hit first ends the run.
parser.add_argument("--stop-iters", type=int, default=50)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--stop-reward", type=float, default=0.1)
class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    The agent starts at position 0 and the episode ends once it reaches
    ``end_pos``. You can configure the length of the corridor via the env
    config.

    Actions:
        0 -- move one cell back (ignored while at position 0).
        1 -- move one cell forward.

    Rewards:
        -0.1 for every step that does not finish the episode, 1.0 for the
        step that reaches the end of the corridor.
    """

    def __init__(self, config):
        """Set up the corridor and its action/observation spaces.

        Args:
            config: Mapping that must provide ``"corridor_length"``; the
                value is used as the goal position.

        Raises:
            KeyError: If ``"corridor_length"`` is missing from ``config``.
        """
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        # Positions 0..end_pos (inclusive) are all observable, hence the +1.
        self.observation_space = Discrete(self.end_pos + 1)

    def reset(self):
        """Move the agent back to the start and return the initial position."""
        self.cur_pos = 0
        return self.cur_pos

    def step(self, action):
        """Advance the environment by one action.

        Args:
            action: 0 to step back, 1 to step forward.

        Returns:
            A ``(position, reward, done, info)`` tuple, where ``info`` is
            always an empty dict.

        Raises:
            AssertionError: If ``action`` is neither 0 nor 1.
        """
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        # The episode terminates as soon as the goal position is reached.
        done = self.cur_pos >= self.end_pos
        return self.cur_pos, 1.0 if done else -0.1, done, {}
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()
    # Everything after `ray.init()` runs inside try/finally so that Ray is
    # shut down even if training or the learning check raises.
    try:
        config = {
            # The env class is passed directly; "env_config" is the dict
            # that SimpleCorridor.__init__ receives as `config`.
            "env": SimpleCorridor,
            "env_config": {
                "corridor_length": 5,
            },
        }
        # Honor the --torch flag; without it RLlib's default framework
        # setting is left untouched.
        if args.torch:
            config["framework"] = "torch"

        # Training stops as soon as any one of these criteria is met.
        stop = {
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        }

        results = tune.run(args.run, config=config, stop=stop)

        if args.as_test:
            check_learning_achieved(results, args.stop_reward)
    finally:
        ray.shutdown()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment