Last active
June 25, 2019 01:04
-
-
Save alex-petrenko/5cf4686e6494ad3260c87f00d27b7e49 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VizdoomEnvMultiplayer(VizdoomEnv):
    """Single-player view of a local multiplayer (deathmatch) VizDoom game.

    The instance whose ``player_id`` is 0 hosts the game; every other instance
    joins the host at 127.0.0.1. The underlying ``DoomGame`` is created lazily on
    the first ``reset()``/``step()`` call.
    """

    def __init__(self, level, player_id, num_players, skip_frames, level_map='map01'):
        """Set up the base environment and remember the multiplayer settings.

        Args:
            level: forwarded to the base class constructor.
            player_id: index of this player; 0 makes this instance the host.
            num_players: number of players in the game, host included.
            skip_frames: forwarded to the base class constructor.
            level_map: forwarded to the base class constructor.
        """
        super().__init__(level, skip_frames=skip_frames, level_map=level_map)

        self.player_id = player_id
        self.num_players = num_players

        self.timestep = 0
        # When False, step() advances the game without refreshing the state
        # and returns a tuple of Nones instead of a transition.
        self.update_state = True

    def _is_server(self):
        """Return True when this player is the one hosting the game."""
        return self.player_id == 0

    def _ensure_initialized(self, mode='algo'):
        """Create and start the DoomGame instance unless that already happened.

        Args:
            mode: with the default 'algo' the game window is hidden; any other
                value leaves the window visibility untouched.
        """
        if self.initialized:
            # Doom env already initialized!
            return

        self.game = DoomGame()
        game = self.game

        game.load_config(self.config_path)
        game.set_screen_resolution(self.screen_resolution)
        # Setting an invalid level map will cause the game to freeze silently
        game.set_doom_map(self.level_map)
        game.set_seed(self.rng.random_integers(0, 2**32-1))

        if mode == 'algo':
            game.set_window_visible(False)

        # Player colors:
        # 0 - green, 1 - gray, 2 - brown, 3 - red, 4 - light gray, 5 - light brown, 6 - light red, 7 - light blue
        if self._is_server():
            # This process will function as a host for a multiplayer game with this many players
            # (including the host). It will wait for other machines to connect using the -join
            # parameter and then start the game when everyone is connected.
            host_options = (
                f'-host {self.num_players}',
                '-deathmatch',  # Deathmatch rules are used for the game.
                '+timelimit 10.0',  # The game (episode) will end after this many minutes have elapsed.
                '+sv_forcerespawn 1',  # Players will respawn automatically after they die.
                '+sv_noautoaim 1',  # Autoaim is disabled for all players.
                '+sv_respawnprotect 1',  # Players will be invulnerable for two seconds after spawning.
                '+sv_spawnfarthest 1',  # Players will be spawned as far as possible from any other players.
                '+sv_nocrouch 1',  # Disables crouching.
                '+viz_respawn_delay 1',  # Sets delay between respawns (in seconds).
                '+viz_nocheat 1',  # Disables depth and labels buffer and the ability to use commands
                                   # that could interfere with multiplayer game.
            )
            game.add_game_args(' '.join(host_options))

            # Name your agent and select color
            game.add_game_args('+name Host +colorset 0')
        else:
            # TODO: port, name
            # Join existing game: connect to a host for a multiplayer game.
            game.add_game_args('-join 127.0.0.1')

            # Name your agent and select color
            game.add_game_args('+name AI +colorset 0')

        game.set_mode(Mode.PLAYER)
        game.init()

        self.initialized = True

    def reset(self, mode='algo'):
        """Start a new episode and return its first observation.

        Args:
            mode: forwarded to ``_ensure_initialized``.

        Returns:
            The screen buffer of the initial state with its axes reordered as (1, 2, 0).
        """
        self._ensure_initialized(mode)

        self.timestep, self.update_state = 0, True

        self.game.new_episode()
        self.state = self.game.get_state()

        return np.transpose(self.state.screen_buffer, (1, 2, 0))

    def step(self, action):
        """Advance the game by one tic with the given discrete action.

        Args:
            action: index into the action space; it is converted to a one-hot list.

        Returns:
            ``(observation, reward, done, info)``, or ``(None, None, None, None)``
            when ``self.update_state`` is False. After the episode has finished the
            observation is an all-zero array of the observation space shape.
        """
        self._ensure_initialized()

        info = {'num_frames': self.skip_frames}

        # convert action to vizdoom action space (one hot)
        one_hot = np.zeros(self.action_space.n, dtype=np.uint8)
        one_hot[action] = 1

        self.game.set_action(one_hot.tolist())
        self.game.advance_action(1, self.update_state)

        reward = 0
        reward += self.game.get_last_reward()
        self.timestep += 1

        if not self.update_state:
            return None, None, None, None

        state = self.game.get_state()
        done = self.game.is_episode_finished()

        if done:
            observation = np.zeros(self.observation_space.shape, dtype=np.uint8)
        else:
            observation = np.transpose(state.screen_buffer, (1, 2, 0))
            info.update(self.get_info(self._game_variables_dict(state)))

        return observation, reward, done, info
def safe_get(q, timeout=1e6, msg='Queue timeout'):
    """Wait for the next item on ``q`` and return it.

    The wait is done through repeated ``q.get(timeout=...)`` calls, because using
    queue.get() with a timeout is necessary for KeyboardInterrupt to be handled.
    Whenever a wait expires, ``msg`` is logged and the wait starts again, so the
    function only returns once an item has arrived.

    Args:
        q: queue-like object whose ``get`` accepts a ``timeout`` keyword.
        timeout: seconds to wait in a single attempt.
        msg: message logged each time an attempt times out.
    """
    while True:
        try:
            item = q.get(timeout=timeout)
        except Empty:
            log.exception(msg)
        else:
            return item
class TaskType(Enum):
    """Kinds of tasks that can be placed on a worker's task queue."""

    INIT = 0
    TERMINATE = 1
    RESET = 2
    STEP = 3
    STEP_UPDATE = 4
    INFO = 5
class MultiAgentEnvWorker:
    """Runs the environment of one player in a separate daemon process.

    Tasks arrive on ``task_queue`` as ``(action, task_type)`` pairs. Results of
    RESET, INFO, STEP and STEP_UPDATE tasks are sent back on ``result_queue``;
    INIT and TERMINATE produce no result.
    """

    def __init__(self, player_id, num_players, make_env_func):
        """Create the queues and immediately start the worker process.

        Args:
            player_id: index of the player this worker serves.
            num_players: total number of players in the game.
            make_env_func: callable invoked with ``player_id`` and ``num_players``
                keyword arguments to build the environment.
        """
        self.player_id = player_id
        self.num_players = num_players
        self.make_env_func = make_env_func

        self.task_queue = JoinableQueue()
        self.result_queue = JoinableQueue()

        self.process = Process(target=self.start, daemon=True)
        self.process.start()

    def _init(self):
        """Build this player's environment, seed it with the player id and return it."""
        log.info('Initializing env for player %d...', self.player_id)
        player_env = self.make_env_func(player_id=self.player_id, num_players=self.num_players)
        player_env.seed(self.player_id)
        return player_env

    def _terminate(self, env):
        """Close the environment, logging before and after."""
        log.info('Stop env for player %d...', self.player_id)
        env.close()
        log.info('Env with player %d terminated!', self.player_id)

    @staticmethod
    def _get_info(env):
        """Specific to custom VizDoom environments.

        Returns the result of ``env.unwrapped.get_info_all()`` when that method
        exists, otherwise an empty dict.
        """
        if not hasattr(env.unwrapped, 'get_info_all'):
            return {}
        return env.unwrapped.get_info_all()  # info for the new episode

    def _run_task(self, env, action, task_type):
        """Execute one result-producing task on ``env`` and return what it yields."""
        if task_type == TaskType.RESET:
            return env.reset()
        if task_type == TaskType.INFO:
            return self._get_info(env)
        if task_type in (TaskType.STEP, TaskType.STEP_UPDATE):
            # collect obs, reward, done, and info
            env.unwrapped.update_state = task_type == TaskType.STEP_UPDATE
            return env.step(action)
        raise Exception(f'Unknown task type {task_type}')

    def start(self):
        """Worker loop: serve tasks from the task queue until TERMINATE arrives.

        Every handled task is acknowledged with ``task_queue.task_done()`` so the
        sender can ``join()`` the queue. If a task raises, the exception
        propagates out of the loop without an acknowledgement.
        """
        env = None
        keep_running = True

        while keep_running:
            action, task_type = safe_get(self.task_queue)

            if task_type == TaskType.INIT:
                env = self._init()
            elif task_type == TaskType.TERMINATE:
                self._terminate(env)
                keep_running = False
            else:
                self.result_queue.put(self._run_task(env, action, task_type))

            self.task_queue.task_done()
class VizdoomMultiAgentEnv:
    """Multi-agent environment that drives one worker process per player.

    Observations, rewards, dones and infos are exchanged as dicts keyed by the
    player index converted to a string ('0', '1', ...).
    """

    def __init__(self, num_players, make_env_func, env_config):
        """Read the spaces from a temporary env, then start and initialize the workers.

        Args:
            num_players: number of players, i.e. number of worker processes.
            make_env_func: callable accepting ``player_id`` and ``num_players``
                keyword arguments and returning an environment.
            env_config: accepted but not used by this constructor.
        """
        self.num_players = num_players
        self.skip_frames = 4

        # Temporary environment, created only to read the spaces.
        probe_env = make_env_func(player_id=-1, num_players=num_players)
        self.action_space = probe_env.action_space
        self.observation_space = probe_env.observation_space
        probe_env.close()

        self.workers = [
            MultiAgentEnvWorker(player_idx, num_players, make_env_func)
            for player_idx in range(num_players)
        ]

        for worker in self.workers:
            worker.task_queue.put((None, TaskType.INIT))
            time.sleep(0.1)  # just in case

        # Block until every worker has acknowledged its INIT task.
        for worker in self.workers:
            worker.task_queue.join()

        log.info('%d agent workers initialized!', len(self.workers))

    def await_tasks(self, data, task_type, timeout=None):
        """Send one task to every worker and gather the results per agent.

        Task result is always a tuple of dicts, e.g.:
            (
                {'0': 0th_agent_obs, '1': 1st_agent_obs, ...},
                {'0': 0th_agent_reward, '1': 1st_agent_reward, ...},
                ...
            )

        If your "task" returns only one result per agent (e.g. reset() returns only the
        observation), the result will be a tuple of length 1. It is the responsibility
        of the caller to index appropriately.

        Args:
            data: dict mapping the agent index (as a string) to the payload for that
                agent, or None to send None to every agent.
            task_type: the TaskType to execute on every worker.
            timeout: seconds for a single wait on a worker's result queue; 0.02 when
                None. An expired wait is logged and retried.
        """
        if data is None:
            data = dict.fromkeys(str(idx) for idx in range(self.num_players))

        assert len(data) == self.num_players

        # Workers 1..N-1 receive the task first; worker 0 receives it last.
        dispatch_order = [*range(1, len(self.workers)), 0]
        for idx in dispatch_order:
            self.workers[idx].task_queue.put((data[str(idx)], task_type))

        wait_seconds = timeout if timeout is not None else 0.02

        result_dicts = None
        for idx, worker in enumerate(self.workers):
            worker.task_queue.join()
            results = safe_get(
                worker.result_queue,
                timeout=wait_seconds,
                msg=f'Takes a surprisingly long time to process task {task_type}, retry...',
            )
            worker.result_queue.task_done()

            # A single (non-sequence) result is treated as a one-element result list.
            if not isinstance(results, (tuple, list)):
                results = [results]

            # The first worker's result count decides how many dicts are returned.
            if result_dicts is None:
                result_dicts = tuple({} for _ in results)

            agent_key = str(idx)
            for position, value in enumerate(results):
                result_dicts[position][agent_key] = value

        return result_dicts

    def info(self):
        """Return the per-agent result of an INFO task as a dict keyed by agent index."""
        return self.await_tasks(None, TaskType.INFO)[0]

    def reset(self):
        """Reset every agent's environment and return the per-agent observations."""
        return self.await_tasks(None, TaskType.RESET)[0]

    def step(self, actions):
        """Repeat the given actions for ``skip_frames`` tics and return the last transition.

        The first ``skip_frames - 1`` tics are sent as STEP tasks whose results are
        discarded; only the final STEP_UPDATE task provides the returned values.

        Args:
            actions: dict mapping the agent index (as a string) to that agent's action.

        Returns:
            ``(obs, rew, dones, infos)`` dicts keyed by agent index. ``dones``
            additionally gets an ``'__all__'`` entry that is True when every agent is done.
        """
        for _ in range(self.skip_frames - 1):
            self.await_tasks(actions, TaskType.STEP)

        obs, rew, dones, infos = self.await_tasks(actions, TaskType.STEP_UPDATE)
        dones['__all__'] = all(dones.values())
        return obs, rew, dones, infos

    def close(self):
        """Ask every worker to terminate, then wait for the worker processes to exit."""
        log.info('Stopping multi env...')

        for worker in self.workers:
            worker.task_queue.put((None, TaskType.TERMINATE))
            time.sleep(0.1)

        for worker in self.workers:
            worker.process.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment