-
-
Save bjuergens/baceb62a68c113a3ab770184629f2cb1 to your computer and use it in GitHub Desktop.
Quick test for combining procgen environments with different multiprocessing handlers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
import tap | |
import gc | |
import logging | |
from multiprocessing import Pool | |
import os | |
from dask.distributed import Client, LocalCluster | |
class Params(tap.Tap):
    """Options for the procgen / process-handler quick test."""

    # NOTE: tap turns the inline comments below into the --help text.
    n: int = 100  # how many iterations (one environment is created per iteration)
    info: int = 0  # print the iteration index every `info` iterations; 0 disables printing
    gc_force: bool = False  # call gc.collect() after every work item
    steps: int = 0  # number of random steps to take in each environment; 0 means no reset/steps
    env: str = "procgen:procgen-heist-v0"  # gym environment id to create
    ph: str = "sequence"  # process handler: sequence, mp or dask
def work(steps, gc_force, info, env_id, i):
    """Create one environment, optionally step it randomly, then close it.

    Args:
        steps: Number of random actions to take after ``env.reset()``;
            0 skips both the reset and the stepping.
        gc_force: If true, run ``gc.collect()`` after the environment is closed.
        info: Print ``i`` whenever it is a multiple of ``info``; 0 disables printing.
        env_id: Environment id passed to ``gym.make``.
        i: Index of this work item (used only for progress printing).
    """
    if info and i % info == 0:
        print(i)
    # Extra keyword arguments are procgen options; presumably env_id must be a
    # procgen environment for gym.make to accept them.
    env = gym.make(env_id,
                   distribution_mode="memory",
                   use_monochrome_assets=False,
                   restrict_themes=True,
                   use_backgrounds=False)
    try:
        if steps:
            env.reset()
            for _ in range(steps):
                env.step(env.action_space.sample())
    finally:
        # Release the environment explicitly (even if stepping fails) instead
        # of leaving one open environment behind per call.
        env.close()
    if gc_force:
        logging.info("calling gc.collect()...")
        gc.collect()
def main(args):
    """Run ``args.n`` calls of ``work`` using the handler selected by ``args.ph``.

    Args:
        args: Parsed ``Params``. ``args.ph`` selects the handler:
            "sequence" runs every call in this process, "mp" uses a
            ``multiprocessing.Pool`` and "dask" uses a local dask cluster.
            Both parallel handlers use ``os.cpu_count()`` workers.

    Raises:
        RuntimeError: If ``args.ph`` is not a supported handler name.
    """
    logging.info("starting %s iterations", args.n)
    n = args.n
    if args.ph == "sequence":
        logging.info("starting singlethreaded...")
        for i in range(n):
            work(args.steps, args.gc_force, args.info, args.env, i)
    elif args.ph == "mp":
        logging.info("starting with mp...")
        work_args = [(args.steps, args.gc_force, args.info, args.env, i) for i in range(n)]
        with Pool(os.cpu_count()) as pool:
            pool.starmap(work, work_args)
    elif args.ph == "dask":
        # One worker per CPU, matching the "mp" handler; n_workers=args.n would
        # start a separate worker process for every iteration. The context
        # managers shut the cluster and client down when the run finishes.
        with LocalCluster(processes=True, asynchronous=False, threads_per_worker=1,
                          n_workers=os.cpu_count(), memory_pause_fraction=False,
                          interface="lo") as cluster, Client(cluster) as client:
            logging.info("Dask dashboard available at port: %s",
                         client.scheduler_info()["services"]["dashboard"])
            # client.map takes one iterable per positional parameter of work.
            futures = client.map(work,
                                 [args.steps] * n,
                                 [args.gc_force] * n,
                                 [args.info] * n,
                                 [args.env] * n,
                                 list(range(n)))
            client.gather(futures)
    else:
        raise RuntimeError("unknown value for ph: " + str(args.ph))
    logging.info("done")
if __name__ == "__main__":
    # The root logger defaults to WARNING, which would hide every logging.info
    # progress message (including the dask dashboard port); enable INFO output.
    logging.basicConfig(level=logging.INFO)
    main(Params(underscores_to_dashes=True).parse_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment