@kouroshHakha · Last active June 21, 2022 22:30
Example of how to share arbitrary neural networks between potentially heterogeneous agents for multi-agent RL in RLlib 1.13
"""For an example on how to share observations between agents in MARL look at examples/centralized_critic_2.py
This gist, gives an example of when we want to share neural network weights between agents.
"""
from typing import Tuple
import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv, MultiAgentDict
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import torch.nn as nn
import torch
import gym.spaces as spaces


class CustomMA(MultiAgentEnv):
    """Create an illustrative MultiAgent env."""

    def __init__(self, config=None):
        super().__init__()
        self.action_space = spaces.Dict({
            'A': spaces.Discrete(2),
            'B': spaces.Discrete(3),
        })
        self.observation_space = spaces.Dict({
            'A': spaces.Box(-1, 1, shape=(3,)),
            'B': spaces.Box(-1, 1, shape=(2,)),
        })
        self._agent_ids = {'A', 'B'}
        self._spaces_in_preferred_format = True
        self._nsteps = 0

    def reset(self) -> MultiAgentDict:
        self._nsteps = 0
        return self.observation_space.sample()

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        obs = self.observation_space.sample()
        reward = {'A': 0, 'B': 0}
        self._nsteps += 1
        done = self._nsteps >= 20
        dones = {'A': done, 'B': done, '__all__': done}
        return obs, reward, dones, {}
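
# Quick sanity check of the raw multi-agent env (illustrative sketch, not part
# of the original gist):
#   env = CustomMA()
#   obs = env.reset()                      # {'A': array of shape (3,), 'B': array of shape (2,)}
#   obs, rew, dones, info = env.step({'A': 0, 'B': 1})
#   dones                                  # {'A': False, 'B': False, '__all__': False} until step 20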


class CustomEnvWrapper(MultiAgentEnv):
    """Wrap the env to make it look like a single agent env with a complex
    observation space."""

    def __init__(self, env_config=None):
        super().__init__()
        self.env = CustomMA(env_config)
        self.observation_space = self._obs_space()
        self.action_space = self._act_space()
        self._agent_ids = set(self.observation_space.keys())
        self._spaces_in_preferred_format = False

    def reset(self) -> MultiAgentDict:
        obs = self.env.reset()
        return self._aug_obs(obs)

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        act_out = self._convert_act(action_dict)
        obs, reward, done, info = self.env.step(act_out)
        return self._aug_obs(obs), reward, done, info

    def _obs_space(self):
        pids = self.env.get_agent_ids()
        obs_spaces = {}
        for pid in pids:
            others = []
            for other in pids:
                if other != pid:
                    others.append(other)
            obs_spaces[pid] = spaces.Dict({
                'agent': self.env.observation_space[pid],
                'neighbors': spaces.Dict(
                    {k: self.env.observation_space[k] for k in others}
                ),
            })
        # The resulting space looks like:
        #   'A': {'agent': obs_space_a, 'neighbors': {'B': obs_space_b}}
        #   'B': ...
        return spaces.Dict(obs_spaces)

    def _act_space(self):
        return self.env.action_space

    def _convert_act(self, action_dict):
        return action_dict

    def _aug_obs(self, obs):
        aug_obses_dict = {}
        for pid in obs:
            others = {}
            for other in obs:
                if other != pid:
                    others[other] = obs[other]
            aug_obses_dict[pid] = dict(agent=obs[pid], neighbors=others)
        return {'agent_wrapper': aug_obses_dict}
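
# With the wrapper, RLlib sees a single agent id, 'agent_wrapper', whose
# observation nests every agent's own obs plus its neighbors' obs. A reset
# therefore returns something like (illustrative, not from the original gist):
#   {'agent_wrapper': {
#       'A': {'agent': obs_a, 'neighbors': {'B': obs_b}},
#       'B': {'agent': obs_b, 'neighbors': {'A': obs_a}},
#   }}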


class CustomMAModel(TorchModelV2, nn.Module):
    """Create a custom model that is aware of all the policies, as if it were a
    single agent model.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        original_space = obs_space.original_space
        self.agent_ids = set(original_space.keys())
        custom_config = model_config['custom_model_config']
        latent_dim = custom_config['latent_dim']
        # One encoder per agent, mapping that agent's raw observation to a latent.
        encoder_dict = {
            k: nn.Linear(original_space[k]['agent'].shape[0], latent_dim)
            for k in self.agent_ids
        }
        self.encoders = nn.ModuleDict(encoder_dict)
        # A single torso whose weights are shared by every agent.
        self.shared = nn.Sequential(
            nn.Linear(2 * latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
        )
        # One action head per agent, sized to that agent's action space.
        trunk_dict = {
            k: nn.Linear(latent_dim, action_space[k].n)
            for k in self.agent_ids
        }
        self.trunks = nn.ModuleDict(trunk_dict)
        self.value_head = nn.Linear(latent_dim, 1)
        print('model created')

    def forward(
        self,
        input_dict,
        state,
        seq_lens,
    ):
        obs = input_dict['obs']
        agent_acts = []
        for agent_id in self.agent_ids:
            # Encode the agent's own observation.
            agent_obs = obs[agent_id]['agent']
            agent_state = self.encoders[agent_id](agent_obs)
            # Encode each neighbor's observation with that neighbor's encoder
            # and average the results into a single context vector.
            neis = obs[agent_id]['neighbors']
            nei_states = []
            for nei_key, nei_obs in neis.items():
                nei_states.append(self.encoders[nei_key](nei_obs))
            neighbor_ctx = torch.mean(torch.stack(nei_states, 0), 0)
            # Run [own latent, neighbor context] through the shared torso, then
            # through the agent-specific action head.
            agent_ctx = torch.cat([agent_state, neighbor_ctx], -1)
            agent_latent = self.shared(agent_ctx)
            agent_act = self.trunks[agent_id](agent_latent)
            agent_acts.append(agent_act)
        agent_acts = torch.cat(agent_acts, -1)
        # The value head is not used in this example; a zero value baseline is
        # returned from value_function() instead.
        self._values = torch.zeros((agent_acts.shape[0],)).to(agent_acts.device)
        return agent_acts, state

    def value_function(self):
        return self._values
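
# Because `self.shared` is a single nn.Module reused for every agent id,
# gradients from each agent's action logits flow into the same parameters,
# which is what "sharing weights between agents" means here. A hedged way to
# inspect this after the algorithm is built (attribute names assumed from
# RLlib 1.13's torch policies, not shown in the original gist):
#   model = ppo.get_policy().model
#   print(sum(p.numel() for p in model.shared.parameters()))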


if __name__ == '__main__':
    ray.init(local_mode=True)
    ppo_config = (
        PPOConfig()
        .framework(framework='torch')
        .experimental(_disable_action_flattening=True)
        .training(model={
            'custom_model': CustomMAModel,
            'custom_model_config': {'latent_dim': 128},
        })
        .environment(env=CustomEnvWrapper, disable_env_checking=True)
        .rollouts(num_rollout_workers=0)
        # .multi_agent(
        #     policies={
        #         'A': (None, obs_spaces['A'], None, {}),
        #         'B': (None, obs_spaces['B'], None, {}),
        #     },
        #     policy_mapping_fn=lambda aid, episode, worker, **kw: aid,
        #     observation_fn=observation_fn,
        # )
    )
    ppo = ppo_config.build()
    results = ppo.train()
    print(results)
    breakpoint()
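
    # Optional follow-up (a sketch, not part of the original gist): query the
    # trained default policy on a freshly wrapped observation. 'default_policy'
    # is the id RLlib assigns when `.multi_agent()` is left unconfigured.
    # env = CustomEnvWrapper()
    # obs = env.reset()['agent_wrapper']
    # action = ppo.compute_single_action(obs, policy_id='default_policy', explore=False)
    # print(action)  # expected: {'A': 0 or 1, 'B': 0, 1, or 2}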