#!/usr/bin/env python
"""
Asynchronous hyperparameter search using [Ax](https://ax.dev/) and the submitit
executor, to run on SLURM.

Supports resuming incomplete optimisations from disk, as well as
incremental/partial optimisation.
"""

import os
import asyncio
import time

import cloudpickle
import numpy as np
from submitit import AutoExecutor

from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.exceptions.core import DataRequiredError
from ax.core.base_trial import TrialStatus


def init_or_load_ax_client(
        experiment_name,
        parameters=None,
        objectives=None,
        parameter_constraints=None,
        ax_save_path="",
        resume=False,
        **kwargs):
    """
    Initializes AxClient from a JSON file if available, otherwise creates a new experiment.
    Note that when loading from disk, the `parameters`, `objectives` and
    `parameter_constraints` arguments are ignored.
    """
    if resume and os.path.exists(ax_save_path):
        try:
            ax_client = AxClient.load_from_json_file(ax_save_path)
            print(f"Successfully loaded AxClient state from {ax_save_path}.")
            return ax_client
        except Exception as e:
            print(f"Failed to load AxClient from {ax_save_path}: {e}. Initializing a new AxClient.")

    # If the file does not exist or loading failed, create a new AxClient instance
    ax_client = AxClient()
    ax_client.create_experiment(
        name=experiment_name,
        parameters=parameters,
        objectives=objectives,
        parameter_constraints=parameter_constraints,
        **kwargs
    )
    print(f"Created a new experiment: {experiment_name}.")
    return ax_client
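

# Example call (a sketch; the parameter and objective formats follow Ax's
# create_experiment API, and the paths are illustrative):
#
#   ax_client = init_or_load_ax_client(
#       experiment_name="my_experiment",
#       parameters=[{"name": "x1", "type": "range", "bounds": [0.0, 1.0]}],
#       objectives={"loss": ObjectiveProperties(minimize=True)},
#       ax_save_path="experiments/ax_state.json",
#       resume=True,
#   )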


class JobManager:
    def __init__(self, executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=60):
        self.executor = executor
        self.ax_client = ax_client
        self.ax_save_path = ax_save_path
        self.jobs_state_path = jobs_state_path or f"{ax_save_path}.jobs.pkl"
        self.jobs = {}  # Track jobs by trial_index, storing (job, parameters) tuples
        self.save_lock = asyncio.Lock()
        self.wait_interval = wait_interval

    async def safe_save_state(self):
        async with self.save_lock:
            try:
                # Make sure the directory exists ("" means the current directory)
                os.makedirs(os.path.dirname(self.ax_save_path) or ".", exist_ok=True)
                self.ax_client.save_to_json_file(self.ax_save_path)
                print(f"Successfully saved AxClient state to {self.ax_save_path}")
            except Exception as e:
                print(f"Failed to save AxClient state: {e}")

            # cloudpickle handles objects (e.g. closures) that the standard
            # pickle module may reject
            try:
                with open(self.jobs_state_path, 'wb') as f:
                    cloudpickle.dump(self.jobs, f)
                print(f"Successfully saved jobs state to {self.jobs_state_path}")
            except Exception as e:
                print(f"Failed to save jobs state: {e}")

    @staticmethod
    def load_state(executor, ax_client, ax_save_path, jobs_state_path=None, wait_interval=30):
        jobs_state_path = jobs_state_path or f"{ax_save_path}.jobs.pkl"
        if os.path.exists(jobs_state_path):
            try:
                with open(jobs_state_path, 'rb') as f:
                    jobs = cloudpickle.load(f)
                print(f"Successfully loaded jobs state from {jobs_state_path}")
                job_manager = JobManager(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)
                job_manager.jobs = jobs
                return job_manager
            except Exception as e:
                print(f"Failed to load jobs state: {e}")

        return JobManager(executor, ax_client, ax_save_path, jobs_state_path, wait_interval=wait_interval)

    async def process_job(self, fn, parameters, trial_index, is_new=True):
        if is_new:
            job = self.executor.submit(fn, parameters)
            self.jobs[trial_index] = (job, parameters)
        else:
            job, parameters = self.jobs[trial_index]

        # Reattach the trial to the Ax client if necessary, e.g. after a restart
        # where the jobs were reloaded from disk but the AxClient was recreated
        trial = self.ax_client.experiment.trials.get(trial_index)
        if trial is None or trial.status.is_failed:
            # Reattach the trial and update trial_index if necessary
            parameters, new_trial_index = self.ax_client.attach_trial(parameters=parameters)
            if new_trial_index != trial_index:
                # Update the job's trial index
                del self.jobs[trial_index]
                self.jobs[new_trial_index] = (job, parameters)
                trial_index = new_trial_index

        try:
            print(f"Processing job {job.job_id} for trial {trial_index}")
            # job.awaitable() wraps the submitit job so its result can be
            # awaited without blocking the event loop
            result = await job.awaitable().result()
            self.ax_client.complete_trial(trial_index=trial_index, raw_data=result)
        except Exception as e:
            job_stderr = str(e)
            self.ax_client.log_trial_failure(trial_index=trial_index, metadata={"stderr": job_stderr})
            print(f"Trial {trial_index} failed with error:\n{job_stderr}")

        del self.jobs[trial_index]
        await self.safe_save_state()

    async def process_all_jobs(self, fn=None):
        tasks = []
        for trial_index, (job, parameters) in list(self.jobs.items()):
            # Create a task to re-await each previously submitted job
            # (fn is unused for existing jobs, which are already submitted)
            task = asyncio.create_task(self.process_job(fn, parameters, trial_index, is_new=False))
            tasks.append(task)
        await asyncio.gather(*tasks)

async def run_trials(
    fn, executor, ax_client, trial_budget=25,
    ax_save_path="experiments/ax_state.json",
    job_manager_save_path='experiments/ax_state.jobs.pkl', wait_interval=30,
):
    job_manager = JobManager.load_state(
        executor, ax_client, ax_save_path, job_manager_save_path, wait_interval=wait_interval
    )

    # Process all serialized jobs before starting new ones
    await job_manager.process_all_jobs(fn)

    tasks = []
    trials_submitted = 0
    while trials_submitted < trial_budget:
        try:
            parameters, trial_index = ax_client.get_next_trial()
            task = asyncio.create_task(job_manager.process_job(fn, parameters, trial_index, is_new=True))
            tasks.append(task)
            trials_submitted += 1
        except (MaxParallelismReachedException, DataRequiredError) as e:
            # Ax cannot generate another point yet (parallelism cap reached, or
            # more completed trials are needed), so let in-flight jobs progress
            print(f"Waiting for jobs to complete due to: {type(e).__name__}")
            await asyncio.sleep(wait_interval)

    print(f"Total trials submitted: {trials_submitted}")

    # Wait for all tasks to complete
    await asyncio.gather(*tasks)
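

# Note on resumption: if the process dies while jobs are in flight, the pickled
# jobs dict lets a later run re-await the same SLURM jobs via process_all_jobs,
# and attach_trial re-registers their parameters with a freshly created AxClient.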


###
# Example usage
###

if __name__ == "__main__":
    import matplotlib.pyplot as plt
    # import subprocess
    from dotenv import load_dotenv
    load_dotenv()  # take environment variables from a `.env` file

    # Define the experiment name
    experiment_name = "optim_hartmann6_dev"

    # Define the output directory and ensure it exists
    output_dir = os.path.join('experiments', experiment_name)
    os.makedirs(output_dir, exist_ok=True)

    # State paths, relative to output_dir
    ax_save_path = os.path.join(output_dir, 'ax_state.json')
    job_manager_save_path = os.path.join(output_dir, 'ax_state.jobs.pkl')
    executor_folder = os.path.join(output_dir, 'jobs')

    def model_trial(params):
        """
        A wrapper function which takes the dictionary of parameters from Ax,
        runs the synthetic function hartmann6, and returns the results.
        """
        x = np.array([params.get(f"x{i+1}") for i in range(6)])
        time.sleep(5)
        # Fail ~10% of the time to exercise the failure-handling path
        np.random.seed(int(time.time()))
        if np.random.rand() < 0.1:
            raise ValueError("Randomly failed")
        return {
            "hartmann6": (hartmann6(x), 0.0),
            "l2norm": (np.linalg.norm(x), 0.0),
        }

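    # Ax accepts `raw_data` as {metric_name: (mean, SEM)}; an SEM of 0.0 marks
    # the observation as noiseless.
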
    experiment_params = [
        {
            "name": f"x{i+1}",
            "type": "range",
            "bounds": [0.0, 1.0],
        } for i in range(6)
    ]
    objectives = {
        "hartmann6": ObjectiveProperties(minimize=True),
    }

    ax_client = init_or_load_ax_client(
        experiment_name=experiment_name,
        ax_save_path=ax_save_path,
        parameters=experiment_params,
        objectives=objectives,
        tracking_metric_names=["l2norm"],
        resume=False,  # set True to reload a saved AxClient state from ax_save_path
        choose_generation_strategy_kwargs=dict(
            no_bayesian_optimization=False,
            max_parallelism_override=20,
            use_batch_trials=False,
        ),
    )

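    # The choose_generation_strategy_kwargs above are forwarded to Ax's
    # choose_generation_strategy; max_parallelism_override raises the cap on
    # how many trials may be suggested concurrently.
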
    # Initialize the executor with the experiment's job folder
    executor = AutoExecutor(folder=executor_folder)
    executor.update_parameters(
        timeout_min=119,
        slurm_account=os.getenv('SLURM_ACCOUNT'),
        slurm_array_parallelism=20,
        mem_gb=1,  # more often 32
        cpus_per_task=1,  # more often 8
        gpus_per_node=0,  # more often 1 or 4
        name=experiment_name,
    )

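    # For a quick test without SLURM, submitit can run jobs locally instead:
    # AutoExecutor(folder=executor_folder, cluster="local") executes them in
    # subprocesses, and cluster="debug" runs them in-process.
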
    # Run the trials with the configured paths
    asyncio.run(run_trials(
        model_trial,
        executor,
        ax_client,
        trial_budget=13,  # counts submitted trials, including any that fail
        ax_save_path=ax_save_path,
        job_manager_save_path=job_manager_save_path,
    ))

    # Analyze and plot results
    best_parameters, values = ax_client.get_best_parameters()
    print(f"Best parameters: {best_parameters}, values: {values}")

    best_objectives = []
    for trial in ax_client.experiment.trials.values():
        if trial.status == TrialStatus.COMPLETED:
            best_objectives.append(trial.objective_mean)
            print(trial.objective_mean)
        else:
            print(trial.status)
    # Shape (1, n_trials) so the running minimum can be taken along axis 1
    best_objectives = np.array([best_objectives])

    plt.figure()
    plt.plot(np.minimum.accumulate(best_objectives, axis=1).T, marker="o")

    # Save the plot in the output directory
    plot_path = os.path.join(output_dir, 'optimization_results.png')
    plt.savefig(plot_path)
    print(f"Optimization results plot saved to {plot_path}")

    # Optionally display the plot using imgcat (if in iTerm2 and imgcat is available)
    # subprocess.run(f"imgcat {plot_path}", shell=True)