Last active
September 27, 2020 09:48
-
-
Save willwhitney/e1509c86522896c6930d2fe9ea49a522 to your computer and use it in GitHub Desktop.
Script for running grids of experiments on slurm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#----------------------------------------------
# things to change:
#   code_dir (the full path of the directory that contains your source dir)
#   true_source_dir (change it from TD3 to whatever your source dir is called)
#   job_source_dir (someplace to throw a duplicate of the source dir for this job)
#----------------------------------------------
import os
import sys
import itertools

# Command-line switches read straight from argv (no argparse):
#   --dry-run  do everything except the final `sbatch` submission
#   --clear    re-copy the source into job directories that already exist
dry_run = '--dry-run' in sys.argv
clear = '--clear' in sys.argv

# Output directories for Slurm logs and generated batch scripts.
# exist_ok=True replaces the check-then-create pattern, which could fail if
# another launcher created the directory between the check and the makedirs.
os.makedirs("slurm_logs", exist_ok=True)
os.makedirs("slurm_scripts", exist_ok=True)
# Root directory that contains the source tree and the per-job clones.
code_dir = '/private/home/willwhitney/code'
# Prefix shared by every job name produced by this launch.
basename = "PFnew_start_traj1"

# Environments swept by every grid.
_ENV_NAMES = ('Pusher-v2', 'Striker-v2', 'Thrower-v2')


def _shared_options():
    """Return a fresh dict of the options that every grid sweeps over."""
    return {
        # "start_timesteps": [0],
        "max_timesteps": [1e7],
        "eval_freq": [5e3],
        "render_freq": [1e5],
        "seed": list(range(8)),
    }


# Each grid maps an option name to the list of values to sweep; the launcher
# takes the cartesian product of those lists. Key order is significant: it is
# the order in which flags are emitted and job names are assembled.

# raw
_raw_grid = {
    "main_file": ['main'],
    "env_name": list(_ENV_NAMES),
}
_raw_grid.update(_shared_options())

# learned embedding
_embedded_grid = {
    "main_file": ['main_embedded'],
    "env_name": list(_ENV_NAMES),
    "decoder": [
        # "white_qvel_traj8_z7",
        "white_qvel_traj1_z7",
    ],
}
_embedded_grid.update(_shared_options())

grids = [_raw_grid, _embedded_grid]
# Expand every grid into its cartesian product: one {option: value} dict per
# job, with options kept in the grid's own key order.
jobs = []
for grid in grids:
    option_names = list(grid)
    value_lists = [grid[name] for name in option_names]
    for combo in itertools.product(*value_lists):
        jobs.append(dict(zip(option_names, combo)))
# Report how many jobs were generated, and whether they will be submitted.
announcement = "NOT starting {} jobs:" if dry_run else "Starting {} jobs:"
print(announcement.format(len(jobs)))
# Work out which options take more than one value across all grids; only
# those are appended to job names. An option missing from a grid counts as
# the placeholder value "<<NONE>>", so an option present in just some grids
# is treated as varying.
all_keys = set()
for grid in grids:
    all_keys.update(grid)

merged = {key: set() for key in all_keys}
for grid in grids:
    for key, seen in merged.items():
        seen.update(grid.get(key, ["<<NONE>>"]))

varying_keys = {key for key, seen in merged.items() if len(seen) > 1}
# Options consumed by this launcher itself and never forwarded as CLI flags.
excluded_flags = {'main_file'}

# Canonical source directory; every generated script `cd`s here before running.
# It does not depend on the job, so it is computed once outside the loop.
true_source_dir = code_dir + '/TD3'

# For each job: build its flags and name, snapshot the code into a per-job
# clone, write a Slurm batch script, and (unless --dry-run) submit it.
for job in jobs:
    jobname = basename
    flagstring = ""
    for flag in job:
        # construct the string of arguments to be passed to the script
        if flag not in excluded_flags:
            if isinstance(job[flag], bool):
                # booleans become bare switches: emitted when True,
                # dropped (with a warning) when False
                if job[flag]:
                    flagstring = flagstring + " --" + flag
                else:
                    print("WARNING: Excluding 'False' flag " + flag)
            else:
                flagstring = flagstring + " --" + flag + " " + str(job[flag])

        # construct the job's name
        if flag in varying_keys:
            jobname = jobname + "_" + flag + str(job[flag])
    flagstring = flagstring + " --name " + jobname

    # The dirname only differs from the base directory when jobname contains
    # a path separator; makedirs then creates the intermediate directories.
    slurm_script_path = 'slurm_scripts/' + jobname + '.slurm'
    slurm_script_dir = os.path.dirname(slurm_script_path)
    os.makedirs(slurm_script_dir, exist_ok=True)

    slurm_log_dir = 'slurm_logs/' + jobname
    os.makedirs(os.path.dirname(slurm_log_dir), exist_ok=True)

    # Per-job copy of the code, so later edits don't affect queued jobs.
    job_source_dir = code_dir + '/TD3-clones/' + jobname
    try:
        os.makedirs(job_source_dir)
        copy_source = True
    except FileExistsError:
        # with the 'clear' flag, we're starting fresh
        # overwrite the code that's already here
        copy_source = clear
        if clear:
            print("Overwriting existing files.")
    if copy_source:
        # Copies the contents of the current working directory (the shell
        # glob skips dotfiles). A failed copy is reported rather than
        # silently ignored, since the job would then run against a missing
        # or stale clone.
        if os.system('cp -R ./* ' + job_source_dir) != 0:
            print("WARNING: copying source into " + job_source_dir + " failed")

    jobcommand = "python {}/{}.py{}".format(job_source_dir, job['main_file'], flagstring)
    job_start_command = "sbatch " + slurm_script_path
    # jobcommand += " --restart-command '{}'".format(job_start_command)
    print(jobcommand)

    script_lines = [
        "#!/bin/bash",
        "#SBATCH --job-name" + "=" + jobname,
        "#SBATCH --open-mode=append",  # append to existing log files
        "#SBATCH --output=slurm_logs/" + jobname + ".out",
        "#SBATCH --error=slurm_logs/" + jobname + ".err",
        "#SBATCH --export=ALL",
        "#SBATCH --signal=USR1@600",  # SIGUSR1 600 s before the time limit
        "#SBATCH --time=1-00",  # days-hours: one day
        "#SBATCH -N 1",
        "#SBATCH --mem=32gb",
        "#SBATCH -c 4",
        "#SBATCH --gres=gpu:1",
        "cd " + true_source_dir,
        "srun " + jobcommand,
    ]
    # Single write; every line, including the last, ends with a newline.
    with open(slurm_script_path, 'w') as slurmfile:
        slurmfile.write("\n".join(script_lines) + "\n")

    if not dry_run:
        # Backgrounded so the launcher doesn't wait on each sbatch call.
        os.system(job_start_command + " &")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment