Example of how to share arbitrary neural networks between potentially heterogeneous agents for multi-agent RL in RLlib 1.13
"""For an example on how to share observations between agents in MARL look at examples/centralized_critic_2.py | |
This gist, gives an example of when we want to share neural network weights between agents. | |
""" | |
from typing import Tuple | |
import ray | |
from ray.rllib.env.multi_agent_env import MultiAgentEnv, MultiAgentDict | |
from ray.rllib.algorithms.ppo import PPOConfig | |
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 | |
import torch.nn as nn | |
import torch | |
import gym.spaces as spaces | |
class CustomMA(MultiAgentEnv):
    """An illustrative multi-agent env with two heterogeneous agents."""

    def __init__(self, config=None):
        super().__init__()
        self.action_space = spaces.Dict({
            'A': spaces.Discrete(2),
            'B': spaces.Discrete(3),
        })
        self.observation_space = spaces.Dict({
            'A': spaces.Box(-1, 1, shape=(3,)),
            'B': spaces.Box(-1, 1, shape=(2,)),
        })
        self._agent_ids = {'A', 'B'}
        self._spaces_in_preferred_format = True
        self._nsteps = 0

    def reset(self) -> MultiAgentDict:
        self._nsteps = 0
        return self.observation_space.sample()

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        obs = self.observation_space.sample()
        reward = {'A': 0, 'B': 0}
        self._nsteps += 1
        done = self._nsteps >= 20
        dones = {'A': done, 'B': done, '__all__': done}
        return obs, reward, dones, {}

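# A minimal usage sketch (added here for illustration, not part of the
# original gist) of what a single rollout step of CustomMA produces:
#
#   env = CustomMA()
#   obs = env.reset()                          # {'A': array of shape (3,), 'B': array of shape (2,)}
#   obs, rew, dones, _ = env.step(env.action_space.sample())
#   # rew   -> {'A': 0, 'B': 0}
#   # dones -> {'A': False, 'B': False, '__all__': False} until step 20
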
class CustomEnvWrapper(MultiAgentEnv):
    """Wrap the env so that it looks like a single-agent env with a complex
    (nested dict) observation space."""

    def __init__(self, env_config=None):
        super().__init__()
        self.env = CustomMA(env_config)
        self.observation_space = self._obs_space()
        self.action_space = self._act_space()
        self._agent_ids = set(self.observation_space.keys())
        self._spaces_in_preferred_format = False

    def reset(self) -> MultiAgentDict:
        obs = self.env.reset()
        return self._aug_obs(obs)

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        act_out = self._convert_act(action_dict)
        obs, reward, done, info = self.env.step(act_out)
        return self._aug_obs(obs), reward, done, info

    def _obs_space(self):
        # Each agent observes its own obs plus all other agents' obs under a
        # 'neighbors' key, e.g.:
        #   'A': {'agent': obs_space_a, 'neighbors': {'B': obs_space_b}}
        #   'B': {'agent': obs_space_b, 'neighbors': {'A': obs_space_a}}
        pids = self.env.get_agent_ids()
        obs_spaces = {}
        for pid in pids:
            others = []
            for other in pids:
                if other != pid:
                    others.append(other)
            obs_spaces[pid] = spaces.Dict({
                'agent': self.env.observation_space[pid],
                'neighbors': spaces.Dict(
                    {k: self.env.observation_space[k] for k in others}),
            })
        return spaces.Dict(obs_spaces)

    def _act_space(self):
        return self.env.action_space

    def _convert_act(self, action_dict):
        return action_dict

    def _aug_obs(self, obs):
        aug_obses_dict = {}
        for pid in obs:
            others = {}
            for other in obs:
                if other != pid:
                    others[other] = obs[other]
            aug_obses_dict[pid] = dict(agent=obs[pid], neighbors=others)
        return {'agent_wrapper': aug_obses_dict}

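# For reference (derived from the code above, not part of the original gist),
# CustomEnvWrapper.reset()/step() return observations of the form:
#
#   {'agent_wrapper': {
#       'A': {'agent': obs_A, 'neighbors': {'B': obs_B}},
#       'B': {'agent': obs_B, 'neighbors': {'A': obs_A}},
#   }}
#
# i.e. the whole multi-agent observation is presented under a single agent id
# ('agent_wrapper'), so RLlib trains one policy/model that sees all agents.
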
class CustomMAModel(TorchModelV2, nn.Module):
    """A custom model that is aware of all the agents, as if it were a single
    agent model."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        original_space = obs_space.original_space
        # Keep the agent ids in a fixed (sorted) order so that action logits
        # are concatenated in the same order as the Dict action space keys.
        self.agent_ids = sorted(original_space.keys())
        custom_config = model_config['custom_model_config']
        latent_dim = custom_config['latent_dim']
        # One encoder per agent; input sizes may differ (heterogeneous agents).
        encoder_dict = {
            k: nn.Linear(original_space[k]['agent'].shape[0], latent_dim)
            for k in self.agent_ids
        }
        self.encoders = nn.ModuleDict(encoder_dict)
        # Trunk shared by all agents: it consumes the agent's own encoding
        # concatenated with the mean encoding of its neighbors.
        self.shared = nn.Sequential(
            nn.Linear(2 * latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
        )
        # One action head per agent; output sizes may differ as well.
        trunk_dict = {
            k: nn.Linear(latent_dim, action_space[k].n)
            for k in self.agent_ids
        }
        self.trunks = nn.ModuleDict(trunk_dict)
        self.value_head = nn.Linear(latent_dim, 1)
        print('model created')
    def forward(
        self,
        input_dict,
        state,
        seq_lens,
    ):
        obs = input_dict['obs']
        agent_acts = []
        for agent_id in self.agent_ids:
            # Encode the agent's own observation with its own encoder.
            agent_obs = obs[agent_id]['agent']
            agent_state = self.encoders[agent_id](agent_obs)
            # Encode the neighbors' observations with *their* encoders
            # (the same shared weights) and average the encodings.
            neis = obs[agent_id]['neighbors']
            nei_states = []
            for nei_key, nei_obs in neis.items():
                nei_states.append(self.encoders[nei_key](nei_obs))
            neighbor_ctx = torch.mean(torch.stack(nei_states, 0), 0)
            agent_ctx = torch.cat([agent_state, neighbor_ctx], -1)
            agent_latent = self.shared(agent_ctx)
            agent_act = self.trunks[agent_id](agent_latent)
            agent_acts.append(agent_act)
        agent_acts = torch.cat(agent_acts, -1)
        self._values = torch.zeros((agent_acts.shape[0],)).to(agent_acts.device)
        return agent_acts, state

    def value_function(self):
        return self._values

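# Shape sketch for one forward pass (derived from the layers above, assuming
# the latent_dim=128 set in the config below; added for illustration):
#
#   obs['A']['agent']: (B, 3) --encoders['A']--> (B, 128)
#   obs['B']['agent']: (B, 2) --encoders['B']--> (B, 128)
#   [own encoding, mean of neighbor encodings]: (B, 256) --shared--> (B, 128)
#   trunks['A']: (B, 128) -> (B, 2); trunks['B']: (B, 128) -> (B, 3)
#   concatenated logits: (B, 5), matching Dict(Discrete(2), Discrete(3))
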
if __name__ == '__main__':
    ray.init(local_mode=True)
    ppo_config = (
        PPOConfig()
        .framework(framework='torch')
        .experimental(_disable_action_flattening=True)
        .training(model={'custom_model': CustomMAModel,
                         'custom_model_config': {'latent_dim': 128}})
        .environment(env=CustomEnvWrapper, disable_env_checking=True)
        .rollouts(num_rollout_workers=0)
        # .multi_agent(
        #     policies={
        #         'A': (None, obs_spaces['A'], None, {}),
        #         'B': (None, obs_spaces['B'], None, {}),
        #     },
        #     policy_mapping_fn=lambda aid, episode, worker, **kw: aid,
        #     observation_fn=observation_fn,
        # )
    )
    ppo = ppo_config.build()
    results = ppo.train()
    print(results)
    breakpoint()
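    # At the breakpoint, one possible way (an assumption, not part of the
    # original gist) to check that the encoder/trunk weights live in a single
    # shared model is to inspect the default policy:
    #
    #   model = ppo.get_policy().model
    #   print(model.encoders)  # ModuleDict with one encoder per agent id
    #   print(model.shared)    # trunk shared by all agents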