Changing a hyperparameter at runtime via lr_schedule for MARL with ray[rllib]
# -*- coding: utf-8 -*-
"""mod_pol.ipynb

Automatically generated by Colaboratory.

Original file is located at
"""

# Colab setup commands (mount Drive, pin package versions), neutralized inside
# a string literal so they are not executed when this file runs as a script:
"""
# Commented out IPython magic to ensure Python compatibility.
from google.colab import drive
!cat /etc/os-release
drive.mount('/content/gdrive')
# %cd "/content/gdrive/My Drive/Colab Notebooks/misc_code_examples/ray_colab_examples/rock_paper_scissors_multiagent/"
!mkdir chkpt
!pip install tensorflow==2.2.0
!pip install ray[rllib]==0.8.6
"""
from collections import defaultdict
from typing import Dict

from gym.spaces import Discrete
import numpy as np
import argparse
import random

import ray
from ray.tune.logger import pretty_print
#from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
#from ray.rllib.agents.ppo.ppo_tf_policy import LearningRateSchedule
from ray.rllib.agents.ppo import ppo
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.schedules.schedule import Schedule

tf = try_import_tf()

ROCK = 0
PAPER = 1
SCISSORS = 2
class RockPaperScissorsEnv(MultiAgentEnv):
    """Two-player environment for rock paper scissors.

    The observation is simply the last opponent action."""

    def __init__(self, _):
        self.action_space = Discrete(3)
        self.observation_space = Discrete(3)
        self.player1 = "player1"
        self.player2 = "player2"
        self.last_move = None
        self.num_moves = 0
        self.P1_eps_r = 0  # per-episode return accumulators, one per player
        self.P2_eps_r = 0

    def reset(self):
        #print("P1_eps_r =", self.P1_eps_r)
        #print("P2_eps_r =", self.P2_eps_r)
        self.P1_eps_r = 0
        self.P2_eps_r = 0
        self.last_move = (0, 0)
        self.num_moves = 0
        return {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }

    def step(self, action_dict):
        move1 = action_dict[self.player1]
        move2 = action_dict[self.player2]
        self.last_move = (move1, move2)
        obs = {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }
        r1, r2 = {
            (ROCK, ROCK): (0.0, 0.0),
            (ROCK, PAPER): (-1.0, 1.0),
            (ROCK, SCISSORS): (1.0, -1.0),
            (PAPER, ROCK): (1.0, -1.0),
            (PAPER, PAPER): (0.0, 0.0),
            (PAPER, SCISSORS): (-1.0, 1.0),
            (SCISSORS, ROCK): (-1.0, 1.0),
            (SCISSORS, PAPER): (1.0, -1.0),
            (SCISSORS, SCISSORS): (0.0, 0.0),
        }[move1, move2]
        rew = {
            self.player1: r1,
            self.player2: r2,
        }
        self.num_moves += 1
        done = {
            "__all__": self.num_moves >= 10,
        }
        #print("rew =", rew)
        self.P1_eps_r += rew[self.player1]
        self.P2_eps_r += rew[self.player2]
        return obs, rew, done, {}
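
# Optional sanity check -- a minimal sketch, not called anywhere in this
# script (the helper name is mine, not part of the original gist). It rolls
# the env by hand to illustrate the MultiAgentEnv contract used above:
# reset()/step() return dicts keyed by "player1"/"player2", step() takes an
# action dict with one move per player, and an episode ends after 10 moves.
def sanity_check_env():
    env = RockPaperScissorsEnv(None)
    obs = env.reset()            # {"player1": 0, "player2": 0}
    done = {"__all__": False}
    while not done["__all__"]:
        # player1 always plays ROCK, player2 always plays SCISSORS
        obs, rew, done, info = env.step({env.player1: ROCK, env.player2: SCISSORS})
    # ROCK beats SCISSORS on every move, so after the 10-move episode the
    # accumulated returns should be +10 for player1 and -10 for player2.
    assert env.P1_eps_r == 10 and env.P2_eps_r == -10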
train_policy_list = ["agt_0", "agt_1"] | |
use_lstm=True | |
lr = 0.0 | |
gamma = 0.9 | |
policies = {"agt_0": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm}, | |
"lr": lr, | |
"gamma": gamma, | |
"framework": "tf",}), | |
"agt_1": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm}, | |
"lr": lr, | |
"gamma": gamma, | |
"framework": "tf",}), | |
} | |
def select_policy(agent_id): | |
if agent_id == "player1": | |
return "agt_0" | |
else: | |
return "agt_1" | |
def my_config():
    config = ppo.DEFAULT_CONFIG.copy()
    config["multiagent"] = {"policies_to_train": train_policy_list,
                            "policies": policies,
                            "policy_mapping_fn": select_policy,
                            }
    config["num_cpus_per_worker"] = 0.25
    #config["num_gpus_per_worker"] = 0.25
    config["num_workers"] = 2
    config["num_envs_per_worker"] = 2
    config["batch_mode"] = "truncate_episodes"  # "complete_episodes" or "truncate_episodes"
    # Sample 10 steps per rollout fragment; since an episode lasts exactly 10
    # moves, this is effectively the same as batch_mode="complete_episodes".
    config["rollout_fragment_length"] = 10
    # Training batch size; should be >= rollout_fragment_length. Sampled
    # fragments are concatenated together up to this size and passed to SGD.
    # With batch_mode="complete_episodes", fragments are extended to episode
    # boundaries, so the actual batch can end up larger.
    config["train_batch_size"] = 10
    config["sgd_minibatch_size"] = 10  # default=128; must be <= train_batch_size
    config["num_sgd_iter"] = 3  # default=30; number of SGD epochs per train batch
    config["log_level"] = "WARN"  # WARN/INFO/DEBUG
    #config["callbacks"] = MyCallbacks
    config["lr_schedule"] = None
    #config["lr_schedule"] = [[0, 0.1], [1000, 0.001]]
    return config
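
# A minimal sketch of the built-in alternative hinted at above (the helper
# name and the particular timesteps/values are mine): RLlib's "lr_schedule"
# config takes a list of [timestep, lr] pairs and interpolates between them,
# so the learning rate can be annealed without touching the policy object at
# runtime. Not called anywhere in this script.
def config_with_lr_schedule():
    config = my_config()
    config["lr_schedule"] = [[0, 0.01], [50000, 0.0001]]  # decay lr from 1e-2 to 1e-4 over 50k timesteps
    return config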
def go_train():
    trainer = PPOTrainer(config=my_config(), env=RockPaperScissorsEnv)

    is_hyp_chg = False
    for i in range(300):
        result = trainer.train()

        if i > 30 and not is_hyp_chg:
            # hard coded "new" hyperparameters:
            lr = 0.001  # only used by the commented-out agt_0 update below
            #gamma = 0.9

            #agt_0_pol = trainer.get_policy('agt_0')
            #agt_0_pol.lr_schedule = ConstantSchedule(lr, framework=None)

            # Swap agt_1's learning-rate schedule at runtime. Because the
            # is_hyp_chg flag update below is commented out, a fresh random
            # constant schedule is assigned on every iteration after the 30th.
            agt_1_pol = trainer.get_policy('agt_1')
            agt_1_pol.lr_schedule = ConstantSchedule(np.random.rand(), framework=None)
            #is_hyp_chg = True

        #if i % 10 == 0:
        #    print(pretty_print(result))
        print(pretty_print(result))
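
# A minimal sketch (not used above; the helper name and values are mine) of
# swapping in a piecewise schedule at runtime instead of a constant one.
# PiecewiseSchedule is already imported; its endpoints are (timestep, value)
# pairs and outside_value covers timesteps past the last endpoint.
def set_piecewise_lr(trainer, policy_id="agt_1"):
    pol = trainer.get_policy(policy_id)
    pol.lr_schedule = PiecewiseSchedule(
        [(0, 0.001), (100000, 0.00001)],   # anneal 1e-3 -> 1e-5 over 100k timesteps
        outside_value=0.00001,             # lr used after 100k timesteps
        framework=None)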
#register_env("RockPaperScissorsEnv", lambda _: RockPaperScissorsEnv(_))

ray.shutdown()
ray.init(ignore_reinit_error=True, log_to_driver=True, webui_host='127.0.0.1',
         num_cpus=2, num_gpus=0)  # start ray

go_train()