Last active
October 26, 2017 06:40
-
-
Save pekaalto/2d465fea6db3d58b0de75709dea61623 to your computer and use it in GitHub Desktop.
Just editing baselines subprocvecenv for sc2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Almost a direct copy of https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py, edited for StarCraft II (pysc2) environments.
""" | |
from multiprocessing import Process, Pipe | |
from pysc2.env import sc2_env, available_actions_printer | |
def worker(remote, env_fn_wrapper):
    """Serve one environment inside a subprocess, driven by pipe commands.

    The single-player conversions happen here: an incoming ``action`` is
    passed to the environment as ``[action]``, and the one-element list of
    timesteps it returns is unwrapped to a single timestep before being
    sent back.

    Commands arrive as ``(cmd, action)`` tuples:

    * ``'step'``  -- step the environment with ``action`` and send the timestep.
    * ``'reset'`` -- reset the environment and send the first timestep.
    * ``'close'`` -- leave the loop and shut down.

    Args:
        remote: Worker end of the pipe; must provide ``recv``, ``send`` and
            ``close``.
        env_fn_wrapper: Object whose ``x`` attribute is a zero-argument
            callable that builds the environment.

    Raises:
        NotImplementedError: If an unknown command is received.
    """
    env = env_fn_wrapper.x()
    try:
        while True:
            cmd, action = remote.recv()
            if cmd == 'step':
                timesteps = env.step([action])
            elif cmd == 'reset':
                timesteps = env.reset()
            elif cmd == 'close':
                break
            else:
                raise NotImplementedError('Unknown command: {!r}'.format(cmd))
            # Single-player setup: exactly one timestep is expected back.
            assert len(timesteps) == 1
            remote.send(timesteps[0])
    finally:
        # Always release the environment, even when the loop exits through an
        # exception; otherwise whatever it launched outlives this worker.
        try:
            env.close()
        finally:
            remote.close()
class SC2VecEnv:
    """Run several environments in parallel, one worker subprocess each.

    Each environment is built and driven inside its own ``worker`` process;
    this object talks to the workers through one pipe per environment.
    """

    def __init__(self, env_fns):
        """Create the pipes and start one worker process per factory.

        Args:
            env_fns: Sequence of zero-argument callables; each one is called
                inside its worker process to build that worker's environment.
        """
        n_envs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(n_envs)])
        self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
        for p in self.ps:
            p.start()
        self.n_envs = n_envs
        self.closed = False

    def _step_or_reset(self, command, actions=None):
        """Send ``command`` to every worker and collect one reply from each.

        Args:
            command: Command string understood by the worker.
            actions: One action per environment, or ``None`` to send ``None``
                as the action to every worker.

        Returns:
            List with one reply per environment, in environment order.

        Raises:
            ValueError: If ``actions`` does not hold exactly one entry per
                environment.
        """
        # Compare against None explicitly: truth-testing an array-like action
        # batch can raise, and an empty batch must not be silently replaced.
        if actions is None:
            actions = [None] * self.n_envs
        elif len(actions) != self.n_envs:
            # zip() would silently truncate; with too few actions some workers
            # would get no command and the recv() below would block forever.
            raise ValueError(
                'Expected {} actions, got {}'.format(self.n_envs, len(actions)))
        for remote, action in zip(self.remotes, actions):
            remote.send((command, action))
        return [remote.recv() for remote in self.remotes]

    def step(self, actions):
        """Step every environment with its action; return one reply each."""
        return self._step_or_reset("step", actions)

    def reset(self):
        """Reset every environment; return one reply each."""
        return self._step_or_reset("reset", None)

    def close(self):
        """Tell all workers to stop, wait for them, and close the pipes.

        Calling this more than once is a no-op after the first call.
        """
        if self.closed:
            return
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        for remote in self.remotes:
            remote.close()
        self.closed = True

    def reset_done_envs(self):
        """Placeholder; currently does nothing."""
        pass
def make_sc2env(**kwargs):
    """Build an ``sc2_env.SC2Env``, forwarding all keyword arguments as given."""
    # Optional wrapper, currently disabled and kept for reference:
    #   available_actions_printer.AvailableActionsPrinter(env)
    return sc2_env.SC2Env(**kwargs)
# This is the original baselines one | |
class CloudpickleWrapper(object):
    """Holder whose payload is serialized with cloudpickle.

    Without this, multiprocessing would try to serialize the payload with
    plain pickle. The payload is available as the ``x`` attribute.
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        # Imported lazily so cloudpickle is only needed when pickling happens.
        from cloudpickle import dumps
        return dumps(self.x)

    def __setstate__(self, ob):
        # ``ob`` is the byte string produced by __getstate__.
        from pickle import loads
        self.x = loads(ob)
""" | |
# To use, do something like this:
from functools import partial

env_args = dict(
    map_name=FLAGS.map_name,
    step_mul=FLAGS.step_mul,
    game_steps_per_episode=0,
    screen_size_px=(FLAGS.resolution,) * 2,
    minimap_size_px=(FLAGS.resolution,) * 2,
    visualize=FLAGS.render
)
envs = SC2VecEnv((partial(make_sc2env, **env_args),) * FLAGS.n_envs)
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment