Asynchronous hyperparam search using Ax and the submitit executor, to run on SLURM.
Supports resumption of incomplete optimisations from disk, and (in principle) incremental/partial optimisation.
Note: slurm/submitit is now officially supported by Ax
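In outline: the script below submits each Ax trial as a SLURM job through submitit, awaits the jobs with asyncio, and checkpoints both the Ax state and the in-flight job handles to disk, so a killed run can pick up where it left off.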
#!/usr/bin/env python
"""
Asynchronous hyperparam search using [Ax](https://ax.dev/) and the submitit executor, to run on SLURM.
Supports resumption of incomplete optimisations from disk, and (in principle) incremental/partial optimisation.
"""
import os
import asyncio
import time
from submitit import AutoExecutor, LocalJob, DebugJob
import cloudpickle
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.exceptions.core import DataRequiredError
from ax.core.base_trial import TrialStatus
import numpy as np
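
# LocalJob and DebugJob are not referenced directly below, but submitit can run
# this script without SLURM via AutoExecutor's `cluster` argument ("local" or
# "debug"); see the note at the end of the file.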


def init_or_load_ax_client(
        experiment_name,
        parameters=None,
        objectives=None,
        parameter_constraints=None,
        ax_save_path="",
        resume=False,
        **kwargs):
    """
    Initializes an AxClient from a JSON file if available, otherwise creates a new experiment.
    Note that when state is loaded from disk, the search-space arguments are ignored.
    """
    if resume and os.path.exists(ax_save_path):
        try:
            ax_client = AxClient.load_from_json_file(ax_save_path)
            print(f"Successfully loaded AxClient state from {ax_save_path}.")
            return ax_client
        except Exception as e:
            print(f"Failed to load AxClient from {ax_save_path}: {e}. Initializing a new AxClient.")
    # If resume was not requested, the file does not exist, or loading failed,
    # create a new AxClient instance.
    ax_client = AxClient()
    ax_client.create_experiment(
        name=experiment_name,
        parameters=parameters,
        objectives=objectives,
        parameter_constraints=parameter_constraints,
        **kwargs
    )
    print(f"Created a new experiment: {experiment_name}.")
    return ax_client
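
# A minimal invocation sketch (hypothetical names and paths; the __main__
# block at the bottom shows the real call):
#
#   ax_client = init_or_load_ax_client(
#       experiment_name="demo",
#       parameters=[{"name": "x1", "type": "range", "bounds": [0.0, 1.0]}],
#       objectives={"loss": ObjectiveProperties(minimize=True)},
#       ax_save_path="experiments/demo/ax_state.json",
#       resume=True,
#   )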


class JobManager:
    def __init__(self, executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=60):
        self.executor = executor
        self.ax_client = ax_client
        self.ax_save_path = ax_save_path
        self.jobs_state_path = jobs_state_path or f"{ax_save_path}.jobs.pkl"
        self.jobs = {}  # Track jobs by trial_index, storing (job, parameters) tuples
        self.save_lock = asyncio.Lock()
        self.wait_interval = wait_interval

    async def safe_save_state(self):
        # The lock serialises saves from concurrently running process_job tasks.
        async with self.save_lock:
            try:
                # Make sure the directory exists (guard against a bare filename,
                # for which dirname() returns "").
                os.makedirs(os.path.dirname(self.ax_save_path) or ".", exist_ok=True)
                self.ax_client.save_to_json_file(self.ax_save_path)
                print(f"Successfully saved AxClient state to {self.ax_save_path}")
            except Exception as e:
                print(f"Failed to save AxClient state: {e}")
            try:
                with open(self.jobs_state_path, 'wb') as f:
                    cloudpickle.dump(self.jobs, f)
                print(f"Successfully saved jobs state to {self.jobs_state_path}")
            except Exception as e:
                print(f"Failed to save jobs state: {e}")

    @staticmethod
    def load_state(executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=30):
        jobs_state_path = jobs_state_path or f"{ax_save_path}.jobs.pkl"
        if os.path.exists(jobs_state_path):
            try:
                with open(jobs_state_path, 'rb') as f:
                    jobs = cloudpickle.load(f)
                print(f"Successfully loaded jobs state from {jobs_state_path}")
                job_manager = JobManager(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)
                job_manager.jobs = jobs
                return job_manager
            except Exception as e:
                print(f"Failed to load jobs state: {e}")
        return JobManager(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)

    async def process_job(self, fn, parameters, trial_index, is_new=True):
        if is_new:
            job = self.executor.submit(fn, parameters)
            self.jobs[trial_index] = (job, parameters)
        else:
            # A job deserialised from disk: reattach its trial to the Ax client
            # if the trial is missing or was marked failed.
            job, parameters = self.jobs[trial_index]
            trial = self.ax_client.experiment.trials.get(trial_index)
            if trial is None or trial.status.is_failed:
                parameters, new_trial_index = self.ax_client.attach_trial(parameters=parameters)
                if new_trial_index != trial_index:
                    # Re-key the job under the new trial index
                    del self.jobs[trial_index]
                    self.jobs[new_trial_index] = (job, parameters)
                    trial_index = new_trial_index
        try:
            print(f"Processing job {job.job_id} for trial {trial_index}")
            result = await job.awaitable().result()
            self.ax_client.complete_trial(trial_index=trial_index, raw_data=result)
        except Exception as e:
            # submitit surfaces the remote traceback in the exception text
            job_stderr = str(e)
            self.ax_client.log_trial_failure(trial_index=trial_index, metadata={"stderr": job_stderr})
            print(f"Trial {trial_index} failed with error:\n{job_stderr}")
        del self.jobs[trial_index]
        await self.safe_save_state()

    async def process_all_jobs(self, fn=None):
        # Re-await every deserialised job; `fn` is unused for resumed jobs.
        tasks = []
        for trial_index, (job, parameters) in list(self.jobs.items()):
            task = asyncio.create_task(self.process_job(fn, parameters, trial_index, is_new=False))
            tasks.append(task)
        await asyncio.gather(*tasks)
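
# A resumption-only sketch (hypothetical usage; run_trials() below does this
# before drawing fresh trials): drain previously submitted jobs without
# spending any new trial budget.
#
#   manager = JobManager.load_state(executor, ax_client, ax_save_path)
#   asyncio.run(manager.process_all_jobs())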


async def run_trials(
        fn, executor, ax_client, trial_budget=25,
        ax_save_path="experiments/ax_state.json",
        job_manager_save_path='experiments/ax_state.jobs.pkl', wait_interval=30,
):
    job_manager = JobManager.load_state(
        executor, ax_client, ax_save_path, job_manager_save_path, wait_interval=wait_interval
    )
    # Process all serialized jobs before starting new ones; resumed jobs do
    # not count against trial_budget.
    await job_manager.process_all_jobs(fn)
    tasks = []
    trials_submitted = 0
    while trials_submitted < trial_budget:
        try:
            parameters, trial_index = ax_client.get_next_trial()
            task = asyncio.create_task(job_manager.process_job(fn, parameters, trial_index, is_new=True))
            tasks.append(task)
            trials_submitted += 1
        except (MaxParallelismReachedException, DataRequiredError) as e:
            # Ax cannot generate another point yet; sleep so in-flight
            # process_job tasks get a chance to complete and report data.
            print(f"Waiting for jobs to complete due to: {type(e).__name__}")
            await asyncio.sleep(wait_interval)
    print(f"Total trials submitted: {trials_submitted}")
    # Wait for all tasks to complete
    await asyncio.gather(*tasks)


###
# Example usage
###
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    # import subprocess
    from dotenv import load_dotenv
    load_dotenv()  # take environment variables from a `.env` file

    # Define the experiment parameters
    experiment_name = "optim_hartmann6_dev"

    # Define the output directory and ensure it exists
    output_dir = os.path.join('experiments', experiment_name)
    os.makedirs(output_dir, exist_ok=True)

    # Paths relative to output_dir
    ax_save_path = os.path.join(output_dir, 'ax_state.json')
    job_manager_save_path = os.path.join(output_dir, 'ax_state.jobs.pkl')
    executor_folder = os.path.join(output_dir, 'jobs')

    def model_trial(params):
        """
        A wrapper function which takes the dictionary of parameters from Ax,
        runs the synthetic function hartmann6, and returns the results.
        """
        x = np.array([params.get(f"x{i+1}") for i in range(6)])
        time.sleep(5)
        # Reseed per job; note that jobs starting in the same second share a seed.
        np.random.seed(int(time.time()))
        if np.random.rand() < 0.1:
            raise ValueError("Randomly failed")
        # Ax expects raw_data as {metric_name: (mean, SEM)}; an SEM of 0.0
        # declares the observation noiseless.
        return {
            "hartmann6": (hartmann6(x), 0.0),
            "l2norm": (np.linalg.norm(x), 0.0)
        }

    experiment_params = [
        {
            "name": f"x{i+1}",
            "type": "range",
            "bounds": [0.0, 1.0]
        } for i in range(6)
    ]
    objectives = {
        "hartmann6": ObjectiveProperties(minimize=True)
    }

    ax_client = init_or_load_ax_client(
        experiment_name=experiment_name,
        ax_save_path=ax_save_path,
        parameters=experiment_params,
        objectives=objectives,
        tracking_metric_names=["l2norm"],
        resume=False,  # set True to pick up a previously saved Ax state
        choose_generation_strategy_kwargs=dict(
            no_bayesian_optimization=False,
            max_parallelism_override=20,
            use_batch_trials=False,
        )
    )

    # Initialize the executor with the updated folder
    executor = AutoExecutor(folder=executor_folder)
    executor.update_parameters(
        timeout_min=119,
        slurm_account=os.getenv('SLURM_ACCOUNT'),
        slurm_array_parallelism=20,
        mem_gb=1,  # more often 32
        cpus_per_task=1,  # more often 8
        gpus_per_node=0,  # more often 1 or 4
        name=experiment_name,
    )

    # Run the trials with updated paths
    asyncio.run(run_trials(
        model_trial,
        executor,
        ax_client,
        trial_budget=13,
        ax_save_path=ax_save_path,
        job_manager_save_path=job_manager_save_path,
    ))

    # Analyze and plot results
    best_parameters, values = ax_client.get_best_parameters()
    print(f"Best parameters: {best_parameters}, values: {values}")
    best_objectives = []
    for trial in ax_client.experiment.trials.values():
        if trial.status == TrialStatus.COMPLETED:
            best_objectives.append(trial.objective_mean)
            print(trial.objective_mean)
        else:
            print(trial.status)
    best_objectives = np.array([best_objectives])
    plt.figure()
    plt.plot(np.minimum.accumulate(best_objectives, axis=1).T, marker="o")

    # Save the plot in the output directory
    plot_path = os.path.join(output_dir, 'optimization_results.png')
    plt.savefig(plot_path)
    print(f"Optimization results plot saved to {plot_path}")

    # Optionally display the plot using imgcat (if in iTerm2 and imgcat is available)
    # subprocess.run(f"imgcat {plot_path}", shell=True)
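
    # To smoke-test the whole pipeline without a SLURM allocation, submitit's
    # AutoExecutor accepts a cluster override (a sketch; everything else stays
    # unchanged): "debug" runs jobs in-process, "local" in subprocesses.
    #
    #   executor = AutoExecutor(folder=executor_folder, cluster="debug")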