Created
September 21, 2021 11:20
-
-
Save heiner/67955538cce6375dbdd6fa97ffd11ce9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import queue | |
import threading | |
import gym | |
def target(resetqueue, readyqueue):
    """Worker loop that resets environments handed to it over a queue.

    Every environment taken from ``resetqueue`` is reset, and the pair
    ``(observation, env)`` is pushed onto ``readyqueue``. Receiving ``None``
    is the shutdown signal: the loop ends and the function returns ``None``.
    """
    env = resetqueue.get()
    while env is not None:
        observation = env.reset()
        readyqueue.put((observation, env))
        env = resetqueue.get()
class CachedEnvWrapper(gym.Env):
    """Env wrapper that hides ``reset()`` latency behind a pool of environments.

    ``envs[0]`` starts as the active environment. Every other environment is
    reset ahead of time by background threads. ``reset()`` hands the active env
    back to the workers and switches to whichever env finished resetting first.

    Args:
        envs: Iterable of environments to rotate through. Must be non-empty.
        num_threads: Number of background reset threads (at least 1).

    Raises:
        IndexError: if ``envs`` is empty.
        ValueError: if ``num_threads`` is less than 1.
    """

    def __init__(self, envs, num_threads=2):
        self._envs = list(envs)
        # Validate before any thread starts: otherwise an error here would leave
        # workers blocked on the queue forever.
        if num_threads < 1:
            # With zero workers nothing ever reaches the ready queue and
            # reset() would block forever.
            raise ValueError("num_threads must be at least 1, got %r" % (num_threads,))
        self._env = self._envs[0]  # Raises IndexError for an empty envs.
        self._closed = False
        # This could alternatively also use concurrent.futures. I hesitate to do
        # that as futures.wait would have me deal with sets all the time where they
        # are really not necessary.
        self._resetqueue = queue.SimpleQueue()
        self._readyqueue = queue.SimpleQueue()
        self._threads = [
            threading.Thread(
                target=self._reset_worker,
                args=(self._resetqueue, self._readyqueue),
                # Daemon threads keep a forgotten close() from blocking
                # interpreter exit; close() still joins them cleanly.
                daemon=True,
            )
            for _ in range(num_threads)
        ]
        for t in self._threads:
            t.start()
        for env in self._envs[1:]:
            self._resetqueue.put(env)

    @staticmethod
    def _reset_worker(resetqueue, readyqueue):
        """Reset envs from ``resetqueue`` until the ``None`` sentinel arrives.

        Pushes ``(obs, env, None)`` on success and ``(None, env, exc)`` when
        ``env.reset()`` raises. The failure then reaches the caller of
        ``reset()`` instead of silently killing this thread and leaving
        ``reset()`` blocked forever.
        """
        while True:
            env = resetqueue.get()
            if env is None:
                return
            try:
                obs = env.reset()
            except Exception as exc:  # Forwarded to and re-raised by reset().
                readyqueue.put((None, env, exc))
            else:
                readyqueue.put((obs, env, None))

    def reset(self):
        """Queue the active env for reset and switch to the first ready env.

        Returns:
            The initial observation of the newly active env.

        Raises:
            Whatever that env's ``reset()`` raised in the worker thread. The env
            still becomes the active one, so the next ``reset()`` retries it.
        """
        self._resetqueue.put(self._env)
        obs, self._env, exc = self._readyqueue.get()
        if exc is not None:
            raise exc
        return obs

    def step(self, action):
        """Forward ``action`` to the active env and return its result."""
        return self._env.step(action)

    def close(self):
        """Stop the worker threads, then close every env. Safe to call twice."""
        if self._closed:
            return
        self._closed = True
        # One sentinel per worker. Envs still queued ahead of the sentinels are
        # reset first, because the queue is FIFO.
        for _ in self._threads:
            self._resetqueue.put(None)
        for t in self._threads:
            t.join()
        for env in self._envs:
            env.close()

    def seed(self, seed=None):
        """Seed the active env and return whatever its ``seed()`` returns."""
        return self._env.seed(seed)  # Unclear if this should happen in all envs?

    @property
    def unwrapped(self):
        """The innermost env behind the currently active env.

        This is a property, like ``gym.Env.unwrapped``, which it overrides.
        """
        return self._env.unwrapped

    def __str__(self):
        return "<CachedEnvWrapper envs=%s>" % [str(env) for env in self._envs]

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
        return False  # Propagate exception.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from concurrent import futures | |
import gym | |
def target(env):
    """Reset ``env`` and return ``(initial_observation, env)``.

    The env is returned alongside its observation so that a caller receiving
    the result (e.g. from a thread pool) knows which env it belongs to.
    """
    return env.reset(), env
class CachedEnvWrapper2(gym.Env):
    """Cached-reset env wrapper built on a caller-supplied thread pool.

    ``envs[0]`` starts as the active env. The others are reset ahead of time on
    ``threadpool``. ``reset()`` submits the active env for reset and switches
    to whichever pending reset finishes first.

    Args:
        envs: Iterable of environments to rotate through. Must be non-empty.
        threadpool: A ``concurrent.futures`` executor. ``close()`` does not shut
            it down; the caller owns it.
        num_workers: Stored as ``self._num_workers``; not otherwise used here.

    Raises:
        IndexError: if ``envs`` is empty.
    """

    def __init__(self, envs, threadpool, num_workers=2):
        self._envs = list(envs)
        self._threadpool = threadpool
        self._num_workers = num_workers  # Previously hard-coded to 2.
        # Maps each pending reset future to the env it is resetting, so the env
        # can be recovered even when its reset() raises.
        self._futures = {}
        self._env = self._envs[0]
        for env in self._envs[1:]:
            self._futures[threadpool.submit(env.reset)] = env

    def step(self, action):
        """Forward ``action`` to the active env and return its result."""
        return self._env.step(action)

    def reset(self):
        """Queue the active env for reset and switch to the first one ready.

        Returns:
            The initial observation of the newly active env.

        Raises:
            Whatever that env's ``reset()`` raised on the pool. The env still
            becomes the active one, so the next ``reset()`` retries it.
        """
        self._futures[self._threadpool.submit(self._env.reset)] = self._env
        done, _ = futures.wait(
            list(self._futures), return_when=futures.FIRST_COMPLETED
        )
        future = next(iter(done))
        # Remove the future before reading its result. Otherwise a failed reset
        # stays pending and every later wait() returns it and re-raises.
        self._env = self._futures.pop(future)
        return future.result()

    def close(self):
        """Wait for in-flight resets, then close every env.

        Waiting first avoids closing an env while a pool thread is still inside
        its ``reset()``. The thread pool itself is left running.
        """
        futures.wait(list(self._futures))
        self._futures.clear()
        for env in self._envs:
            env.close()

    def seed(self, seed=None):
        """Seed the active env and return whatever its ``seed()`` returns."""
        return self._env.seed(seed)  # Unclear if this should happen in all envs?

    @property
    def unwrapped(self):
        """The innermost env behind the currently active env.

        This is a property, like ``gym.Env.unwrapped``, which it overrides.
        """
        return self._env.unwrapped

    def __str__(self):
        # NOTE(review): reports the class as "CachedEnvWrapper" rather than
        # "CachedEnvWrapper2" — confirm whether that is intentional.
        return "<CachedEnvWrapper envs=%s>" % [str(env) for env in self._envs]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment