Ax + SLURM via `submitit` and `asyncio`
#!/usr/bin/env python
"""
Asynchronous hyperparam search using [Ax](https://ax.dev/) and the submitit executor, to run on SLURM.
Supports resumption of incomplete optimisations from disk, and incremental/partial optimisation, I think.
"""
import os
import asyncio
import time
from submitit import AutoExecutor, LocalJob, DebugJob
import cloudpickle
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.exceptions.core import DataRequiredError
from ax.core.base_trial import TrialStatus
import numpy as np
def init_or_load_ax_client(
        experiment_name,
        parameters=None,
        objectives=None,
        parameter_constraints=None,
        ax_save_path="",
        resume=False,
        **kwargs):
"""
    Initializes AxClient from a JSON file when `resume` is set and a saved
    state exists; otherwise creates a new experiment.
    Note that if you load from disk, any parameters passed here are ignored.
    """
    if resume and os.path.exists(ax_save_path):
        try:
            ax_client = AxClient.load_from_json_file(ax_save_path)
            print(f"Successfully loaded AxClient state from {ax_save_path}.")
            return ax_client
        except Exception as e:
            print(f"Failed to load AxClient from {ax_save_path}: {e}. Initializing a new AxClient.")
    # The file does not exist, loading failed, or resume was not requested:
    # create a new AxClient instance
ax_client = AxClient()
ax_client.create_experiment(
name=experiment_name,
parameters=parameters,
objectives=objectives,
parameter_constraints=parameter_constraints,
**kwargs
)
print(f"Created a new experiment: {experiment_name}.")
return ax_client
class JobManager:
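    """
    Tracks submitted submitit jobs keyed by Ax trial index, and persists both
    the AxClient state (JSON) and the job handles (cloudpickle) so that an
    interrupted optimisation can be resumed from disk.
    """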
    def __init__(self, executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=30):
self.executor = executor
self.ax_client = ax_client
self.ax_save_path = ax_save_path
self.jobs_state_path = jobs_state_path or f"{ax_save_path}.jobs.pkl"
self.jobs = {} # Track jobs by trial_index, storing (job, parameters) tuples
self.save_lock = asyncio.Lock()
self.wait_interval = wait_interval
async def safe_save_state(self):
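        """
        Serialize the AxClient to JSON and the job-handle table to a pickle,
        under a lock so concurrent trial completions don't interleave writes.
        """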
async with self.save_lock:
try:
# Make sure the directory exists
                os.makedirs(os.path.dirname(self.ax_save_path) or ".", exist_ok=True)
self.ax_client.save_to_json_file(self.ax_save_path)
print(f"Successfully saved AxClient state to {self.ax_save_path}")
except Exception as e:
print(f"Failed to save AxClient state: {e}")
try:
with open(self.jobs_state_path, 'wb') as f:
cloudpickle.dump(self.jobs, f)
print(f"Successfully saved jobs state to {self.jobs_state_path}")
except Exception as e:
print(f"Failed to save jobs state: {e}")
@staticmethod
def load_state(executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=30):
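        """
        Reconstruct a JobManager from a pickled jobs table if one exists on
        disk; otherwise return a fresh manager with no tracked jobs.
        """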
jobs_state_path = jobs_state_path or f"{ax_save_path}.jobs.pkl"
if os.path.exists(jobs_state_path):
try:
with open(jobs_state_path, 'rb') as f:
jobs = cloudpickle.load(f)
print(f"Successfully loaded jobs state from {jobs_state_path}")
job_manager = JobManager(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)
job_manager.jobs = jobs
return job_manager
except Exception as e:
print(f"Failed to load jobs state: {e}")
return JobManager(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)
async def process_job(self, fn, parameters, trial_index, is_new=True):
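        """
        Submit a new job (or pick up a previously serialized one), await its
        result, and report success or failure back to the Ax client. Trials
        that no longer exist in the experiment are re-attached first.
        """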
if is_new:
job = self.executor.submit(fn, parameters)
self.jobs[trial_index] = (job, parameters)
else:
job, parameters = self.jobs[trial_index]
# Reattach the trial to the Ax client if necessary
trial = self.ax_client.experiment.trials.get(trial_index)
if trial is None or trial.status.is_failed:
# Reattach the trial and update trial_index if necessary
parameters, new_trial_index = self.ax_client.attach_trial(parameters=parameters)
if new_trial_index != trial_index:
# Update the job's trial index
del self.jobs[trial_index]
self.jobs[new_trial_index] = (job, parameters)
trial_index = new_trial_index
try:
print(f"Processing job {job.job_id} for trial {trial_index}")
result = await job.awaitable().result()
self.ax_client.complete_trial(trial_index=trial_index, raw_data=result)
except Exception as e:
job_stderr = str(e)
self.ax_client.log_trial_failure(trial_index=trial_index, metadata={"stderr": job_stderr})
print(f"Trial {trial_index} failed with error:\n{job_stderr}")
del self.jobs[trial_index]
await self.safe_save_state()
async def process_all_jobs(self, fn=None):
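        """Await every job recovered from a previous run, concurrently."""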
tasks = []
for trial_index, (job, parameters) in list(self.jobs.items()):
# Create a task to process each job
task = asyncio.create_task(self.process_job(fn, parameters, trial_index, is_new=False))
tasks.append(task)
await asyncio.gather(*tasks)
async def run_trials(
fn, executor, ax_client, trial_budget=25,
ax_save_path="experiments/ax_state.json",
job_manager_save_path='experiments/ax_state.jobs.pkl', wait_interval=30,
):
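    """
    Main driver: resume any serialized jobs, then keep asking Ax for new
    trials (backing off when parallelism or data constraints block us) until
    the trial budget is exhausted, and await everything.
    """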
job_manager = JobManager.load_state(
executor, ax_client, ax_save_path, job_manager_save_path, wait_interval=wait_interval
)
# Process all serialized jobs before starting new ones
await job_manager.process_all_jobs(fn)
tasks = []
trials_submitted = 0
while trials_submitted < trial_budget:
try:
parameters, trial_index = ax_client.get_next_trial()
task = asyncio.create_task(job_manager.process_job(fn, parameters, trial_index, is_new=True))
tasks.append(task)
trials_submitted += 1
except (MaxParallelismReachedException, DataRequiredError) as e:
print(f"Waiting for jobs to complete due to: {type(e).__name__}")
await asyncio.sleep(wait_interval)
print(f"Total trials submitted: {trials_submitted}")
# Wait for all tasks to complete
await asyncio.gather(*tasks)
###
# Example usage
###
if __name__ == "__main__":
import matplotlib.pyplot as plt
# import subprocess
from dotenv import load_dotenv
load_dotenv() # take environment variables from `.env` file
# Define the experiment parameters
experiment_name = "optim_hartmann6_dev"
# Define the output directory and ensure it exists
output_dir = os.path.join('experiments', experiment_name)
os.makedirs(output_dir, exist_ok=True)
# Update paths to be relative to output_dir
ax_save_path = os.path.join(output_dir, 'ax_state.json')
job_manager_save_path = os.path.join(output_dir, 'ax_state.jobs.pkl')
executor_folder = os.path.join(output_dir, 'jobs')
def model_trial(params):
"""
A wrapper function which takes the dictionary of parameters from Ax,
runs the synthetic function hartmann6, and returns the results.
"""
x = np.array([params.get(f"x{i+1}") for i in range(6)])
time.sleep(5)
np.random.seed(int(time.time()))
if np.random.rand() < 0.1:
raise ValueError("Randomly failed")
return {
"hartmann6": (hartmann6(x), 0.0),
"l2norm": (np.linalg.norm(x), 0.0)
}
experiment_params = [
{
"name": f"x{i+1}",
"type": "range",
"bounds": [0.0, 1.0]
} for i in range(6)
]
objectives = {
"hartmann6": ObjectiveProperties(minimize=True)
}
ax_client = init_or_load_ax_client(
experiment_name=experiment_name,
        ax_save_path=ax_save_path,
parameters=experiment_params,
objectives=objectives,
tracking_metric_names=["l2norm"],
resume=False,
choose_generation_strategy_kwargs=dict(
no_bayesian_optimization=False,
max_parallelism_override=20,
use_batch_trials=False,
)
)
# Initialize the executor with the updated folder
executor = AutoExecutor(folder=executor_folder)
executor.update_parameters(
timeout_min=119,
slurm_account=os.getenv('SLURM_ACCOUNT'),
slurm_array_parallelism=20,
        mem_gb=1,  # more often 32
        cpus_per_task=1,  # more often 8
        gpus_per_node=0,  # more often 1 or 4
name=experiment_name,
)
# Run the trials with updated paths
asyncio.run(run_trials(
model_trial,
executor,
ax_client,
trial_budget=13,
ax_save_path=ax_save_path,
job_manager_save_path=job_manager_save_path,
))
# Analyze and plot results
best_parameters, values = ax_client.get_best_parameters()
print(f"Best parameters: {best_parameters}, values: {values}")
best_objectives = []
for trial in ax_client.experiment.trials.values():
if trial.status == TrialStatus.COMPLETED:
best_objectives.append(trial.objective_mean)
print(trial.objective_mean)
else:
print(trial.status)
best_objectives = np.array([best_objectives])
plt.figure()
plt.plot(np.minimum.accumulate(best_objectives, axis=1).T, marker="o")
# Save the plot in the output directory
plot_path = os.path.join(output_dir, 'optimization_results.png')
plt.savefig(plot_path)
print(f"Optimization results plot saved to {plot_path}")
# Optionally display the plot using imgcat (if in iTerm2 and imgcat is available)
# subprocess.run(f"imgcat {plot_path}", shell=True)