Skip to content

Instantly share code, notes, and snippets.

@crypdick
Created March 11, 2025 17:30
Show Gist options
  • Save crypdick/555bcfe7cb68a4bdb057a78965bf7d68 to your computer and use it in GitHub Desktop.
Save crypdick/555bcfe7cb68a4bdb057a78965bf7d68 to your computer and use it in GitHub Desktop.
from time import sleep
import ray
from ray import tune
from ray.tune.tuner import Tuner
import time
def expensive_setup(duration: float = 1.0):
    """Stand-in for a costly initialization step.

    Prints a marker line, then blocks for ``duration`` seconds. Each call adds
    a fixed, visible cost to a trial's setup, which is what the actor-reuse
    timing comparison below measures.

    Args:
        duration: Seconds to sleep. Defaults to 1.0, the original hard-coded
            cost. Pass a smaller value to make experiments or tests faster.
    """
    print("EXPENSIVE SETUP")
    sleep(duration)
class QuadraticTrainable(tune.Trainable):
    """Toy Tune trainable whose loss is a fixed quadratic in two hyperparameters.

    The reported loss is ``(hparam1 - 3)**2 + (hparam2 - 5)**2``, minimized at
    ``hparam1=3, hparam2=5``. Each trial reports the same loss for
    ``max_steps`` iterations and then signals completion via ``"done"``.
    ``setup`` calls ``expensive_setup()`` to simulate a costly
    initialization, whose cost the script compares with and without
    actor reuse.
    """

    def setup(self, config):
        """Initialize per-trial state and run the simulated expensive setup.

        Args:
            config: Trial configuration. Must contain ``"hparam1"`` and
                ``"hparam2"`` (read in ``step``).
        """
        self.config = config
        expensive_setup()
        self.max_steps = 5
        self.step_count = 0

    def step(self):
        """Run one iteration and return its metrics.

        Returns:
            dict with ``"loss"``. On the ``max_steps``-th call it also holds
            ``"done": True``, which tells Tune to stop the trial.
        """
        h1 = self.config["hparam1"]
        h2 = self.config["hparam2"]
        # Simple quadratic objective with its optimum at hparam1=3, hparam2=5.
        loss = (h1 - 3) ** 2 + (h2 - 5) ** 2
        metrics = {"loss": loss}
        self.step_count += 1
        # ">=" so the trial runs exactly max_steps iterations; ">" ran one extra.
        if self.step_count >= self.max_steps:
            metrics["done"] = True
        return metrics

    def reset_config(self, new_config):
        """Reconfigure this actor for a new trial when actors are reused.

        Any per-trial state must be reset here, because ``setup`` is not run
        again for a reused actor. Without resetting ``step_count``, the next
        trial would inherit the previous trial's count and stop after a
        single step.

        Args:
            new_config: Configuration of the trial now assigned to this actor.

        Returns:
            True, telling Tune the in-place reset succeeded.
        """
        self.config = new_config
        self.step_count = 0
        return True
def _time_tuner_fit(tuner_kwargs, *, reuse_actors, num_samples):
    """Build and run one Tune experiment; return its wall-clock time in seconds.

    Trials run one at a time (``max_concurrent_trials=1``). With
    ``reuse_actors=True``, Tune can hand each later trial to the same actor
    instead of starting a fresh one.

    Args:
        tuner_kwargs: Keyword arguments shared by every ``Tuner`` built here.
        reuse_actors: Forwarded to ``tune.TuneConfig``.
        num_samples: Number of trials to run.

    Returns:
        Elapsed seconds, measured with the monotonic ``time.perf_counter``.
    """
    start = time.perf_counter()
    tuner = Tuner(
        **tuner_kwargs,
        tune_config=tune.TuneConfig(
            num_samples=num_samples,
            max_concurrent_trials=1,
            reuse_actors=reuse_actors,
        ),
    )
    tuner.fit()
    return time.perf_counter() - start


if __name__ == "__main__":
    ray.init()

    # The comparison needs several trials: with num_samples=1 there is no
    # second trial to reuse an actor for, so both runs pay the setup exactly
    # once and the timings are indistinguishable.
    num_samples = 10

    search_space = {
        "hparam1": tune.uniform(-10, 10),
        "hparam2": tune.uniform(-10, 10),
    }
    # Common tuner parameters.
    # NOTE(review): ray.air.RunConfig is deprecated in newer Ray releases in
    # favour of ray.train.RunConfig / tune.RunConfig — confirm against the
    # installed Ray version.
    tuner_kwargs = {
        "trainable": QuadraticTrainable,
        "param_space": search_space,
        "run_config": ray.air.RunConfig(
            verbose=0,
        ),
    }

    print("Running experiment with actor reuse")
    elapsed_reuse = _time_tuner_fit(
        tuner_kwargs, reuse_actors=True, num_samples=num_samples
    )

    print("Running experiment without actor reuse")
    elapsed_no_reuse = _time_tuner_fit(
        tuner_kwargs, reuse_actors=False, num_samples=num_samples
    )

    print("#" * 15)
    print("Timing results")
    print("-" * 15)
    print(f"Timing with reuse_actors=True: {elapsed_reuse:.2f} seconds")
    print(f"Timing without reuse_actors: {elapsed_no_reuse:.2f} seconds")
    print("#" * 15)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment