Testing lr_schedule for ddppo
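The script below defines a two-player rock-paper-scissors MultiAgentEnv, maps each player to its own DDPPO policy, and anneals the learning rate from 0.1 to 1e-5 over the first 1000 timesteps via lr_schedule. It is written against the Ray 0.8.x-era ray.rllib.agents API; newer Ray releases have since reorganized these imports.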
from gym.spaces import Discrete
import numpy as np
import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.agents.ppo import ddppo
from ray.rllib.agents.ppo import DDPPOTrainer
from ray.rllib.utils.framework import try_import_tf
from ray.tune.logger import pretty_print

# In Ray 0.8.x, try_import_tf() returns the tensorflow module
# (or None if TensorFlow is not installed).
tf = try_import_tf()
ROCK = 0
PAPER = 1
SCISSORS = 2
class RockPaperScissorsEnv(MultiAgentEnv):
    """Two-player environment for rock paper scissors.

    The observation is simply the last opponent action."""

    def __init__(self, _):
        self.action_space = Discrete(3)
        self.observation_space = Discrete(3)
        self.player1 = "player1"
        self.player2 = "player2"
        self.last_move = None
        self.num_moves = 0

    def reset(self):
        self.P1_eps_r = 0  # per-episode reward accumulators; set but not
        self.P2_eps_r = 0  # used elsewhere in this script
        self.last_move = (0, 0)
        self.num_moves = 0
        # Each player observes the opponent's last move.
        return {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }
    def step(self, action_dict):
        move1 = action_dict[self.player1]
        move2 = action_dict[self.player2]
        self.last_move = (move1, move2)
        obs = {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }
        # Zero-sum payoff table, keyed by (move1, move2) and yielding
        # (player1 reward, player2 reward).
        r1, r2 = {
            (ROCK, ROCK): (0.0, 0.0),
            (ROCK, PAPER): (-1.0, 1.0),
            (ROCK, SCISSORS): (1.0, -1.0),
            (PAPER, ROCK): (1.0, -1.0),
            (PAPER, PAPER): (0.0, 0.0),
            (PAPER, SCISSORS): (-1.0, 1.0),
            (SCISSORS, ROCK): (-1.0, 1.0),
            (SCISSORS, PAPER): (1.0, -1.0),
            (SCISSORS, SCISSORS): (0.0, 0.0),
        }[move1, move2]
        rew = {
            self.player1: r1,
            self.player2: r2,
        }
        self.num_moves += 1
        done = {
            # Episodes last exactly 10 moves.
            "__all__": self.num_moves >= 10,
        }
        return obs, rew, done, {}
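
# Optional sanity check for the environment above. This helper is an
# illustration only (it is never called by the training code below);
# the name _sanity_check_env is ours, not part of the gist.
def _sanity_check_env():
    env = RockPaperScissorsEnv(None)
    obs = env.reset()
    assert set(obs) == {"player1", "player2"}
    # Rock beats scissors: player1 should get +1, player2 -1.
    _, rew, done, _ = env.step({"player1": ROCK, "player2": SCISSORS})
    assert rew["player1"] == 1.0 and rew["player2"] == -1.0
    assert not done["__all__"]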
train_policy_list = ["agt_0", "agt_1"]
use_lstm = False
lr = 0.1
gamma = 0.9

# Each policy spec is (policy_cls, obs_space, action_space, config);
# policy_cls=None tells RLlib to use the trainer's default policy class.
policies = {
    "agt_0": (None, Discrete(3), Discrete(3),
              {"model": {"use_lstm": use_lstm}, "lr": lr, "gamma": gamma}),
    "agt_1": (None, Discrete(3), Discrete(3),
              {"model": {"use_lstm": use_lstm}, "lr": lr, "gamma": gamma}),
}
def select_policy(agent_id):
    if agent_id == "player1":
        return "agt_0"
    else:
        return "agt_1"
def my_config():
    config = ddppo.DEFAULT_CONFIG.copy()
    config["multiagent"] = {
        "policies_to_train": train_policy_list,
        "policies": policies,
        "policy_mapping_fn": select_policy,
    }
    # DDPPO learns on the rollout workers, so compute resources are
    # allocated per worker rather than to a central learner.
    config["num_cpus_per_worker"] = 0.25
    config["num_gpus_per_worker"] = 0.25
    config["num_workers"] = 2
    config["log_level"] = "WARN"  # WARN/INFO/DEBUG
    # Anneal the learning rate from 0.1 at timestep 0 down to 1e-5 at
    # timestep 1000, and hold it there afterwards.
    config["lr_schedule"] = [[0, 0.1], [1000, 0.00001]]
    return config
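
# Sketch of the decay the schedule above should produce, assuming
# RLlib's lr_schedule interpolates linearly between (timestep, value)
# points and clamps to the last value beyond the final point (as
# ray.rllib.utils.schedules.PiecewiseSchedule does). Defined only to
# reason about the schedule; the trainer never calls it.
def expected_lr(t, schedule=((0, 0.1), (1000, 0.00001))):
    (t0, v0), (t1, v1) = schedule
    if t >= t1:
        return v1
    frac = (t - t0) / (t1 - t0)
    return v0 + frac * (v1 - v0)
# e.g. expected_lr(0) == 0.1, expected_lr(500) ~= 0.05, expected_lr(2000) == 1e-5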
def go_train():
    trainer = DDPPOTrainer(config=my_config(), env=RockPaperScissorsEnv)
    for i in range(100):
        result = trainer.train()
        print(pretty_print(result))


ray.init(ignore_reinit_error=True, log_to_driver=True, num_cpus=2, num_gpus=1)
go_train()
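To confirm the schedule is actually applied, watch each policy's learner stats in the pretty-printed results: the cur_lr value (reported by RLlib's LearningRateSchedule mixin, assuming the Ray 0.8.x APIs this gist targets) should fall from 0.1 toward 1e-5 as timesteps_total passes 1000.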