# Testing lr_schedule for DDPPO (RLlib multi-agent rock-paper-scissors example).
from gym.spaces import Discrete
import numpy as np
import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.agents.ppo import ddppo
from ray.rllib.agents.ppo import DDPPOTrainer
from ray.rllib.utils.framework import try_import_tf
from ray.tune.logger import pretty_print
tf = try_import_tf()
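
# NOTE: this script targets the 2020-era RLlib "agents" API
# (ray.rllib.agents.ppo.ddppo, DDPPOTrainer, try_import_tf() returning a single
# module). Newer Ray releases moved/renamed these, so the imports above may
# need adjusting on current versions.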
ROCK = 0
PAPER = 1
SCISSORS = 2

class RockPaperScissorsEnv(MultiAgentEnv):
    """Two-player environment for rock paper scissors.

    The observation is simply the last opponent action.
    """

    def __init__(self, _):
        self.action_space = Discrete(3)
        self.observation_space = Discrete(3)
        self.player1 = "player1"
        self.player2 = "player2"
        self.last_move = None
        self.num_moves = 0

    def reset(self):
        # Per-episode reward trackers (not used elsewhere in this snippet).
        self.P1_eps_r = 0
        self.P2_eps_r = 0
        self.last_move = (0, 0)
        self.num_moves = 0
        return {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }

    def step(self, action_dict):
        move1 = action_dict[self.player1]
        move2 = action_dict[self.player2]
        self.last_move = (move1, move2)
        # Each player observes the opponent's most recent move.
        obs = {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }
        # Zero-sum payoff table keyed by (player1 move, player2 move).
        r1, r2 = {
            (ROCK, ROCK): (0.0, 0.0),
            (ROCK, PAPER): (-1.0, 1.0),
            (ROCK, SCISSORS): (1.0, -1.0),
            (PAPER, ROCK): (1.0, -1.0),
            (PAPER, PAPER): (0.0, 0.0),
            (PAPER, SCISSORS): (-1.0, 1.0),
            (SCISSORS, ROCK): (-1.0, 1.0),
            (SCISSORS, PAPER): (1.0, -1.0),
            (SCISSORS, SCISSORS): (0.0, 0.0),
        }[move1, move2]
        rew = {
            self.player1: r1,
            self.player2: r2,
        }
        self.num_moves += 1
        done = {
            "__all__": self.num_moves >= 10,
        }
        return obs, rew, done, {}
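
# The env follows RLlib's MultiAgentEnv convention: step() returns obs, reward,
# and done dicts keyed by agent id, and the special "__all__" done key ends the
# episode for all agents once 10 moves have been played.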
train_policy_list = ["agt_0", "agt_1"]
use_lstm=False
lr = 0.1
gamma = 0.9
policies = {"agt_0": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm},
"lr": lr,
"gamma": gamma,}),
"agt_1": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm},
"lr": lr,
"gamma": gamma,}),
}

def select_policy(agent_id):
    if agent_id == "player1":
        return "agt_0"
    else:
        return "agt_1"

def my_config():
    config = ddppo.DEFAULT_CONFIG.copy()
    config["multiagent"] = {
        "policies_to_train": train_policy_list,
        "policies": policies,
        "policy_mapping_fn": select_policy,
    }
    config["num_cpus_per_worker"] = 0.25
    config["num_gpus_per_worker"] = 0.25
    config["num_workers"] = 2
    config["log_level"] = "WARN"  # WARN/INFO/DEBUG
    config["lr_schedule"] = [[0, 0.1], [1000, 0.00001]]
    return config
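
# As far as I understand RLlib's schedule handling, the "lr_schedule" above is a
# list of [timestep, lr] pairs that gets piecewise-linearly interpolated over
# the global timestep and held at the last value afterwards, so the learning
# rate anneals from 0.1 at t=0 down to 1e-5 by t=1000. When a schedule is set,
# it drives the optimizer's current lr and effectively supersedes the static
# "lr" entries in the per-policy configs.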

def go_train():
    trainer = DDPPOTrainer(config=my_config(), env=RockPaperScissorsEnv)
    for i in range(100):
        result = trainer.train()
        print(pretty_print(result))


ray.init(ignore_reinit_error=True, log_to_driver=True, num_cpus=2, num_gpus=1)
go_train()
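
# Optional check (not in the original gist): a minimal sketch of how one might
# confirm the schedule is being applied, assuming the usual result layout of
# result["info"]["learner"][policy_id] (possibly nested under "learner_stats")
# containing "cur_lr". Exact key names vary across RLlib versions, so treat
# this as an assumption to adapt rather than a fixed API.
def print_cur_lr(result):
    learner_info = result.get("info", {}).get("learner", {})
    for policy_id, stats in learner_info.items():
        if isinstance(stats, dict):
            # Some RLlib versions nest per-policy stats under "learner_stats".
            stats = stats.get("learner_stats", stats)
            print(policy_id, "cur_lr:", stats.get("cur_lr"))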