Created
April 9, 2023 07:23
-
-
Save keisuke-umezawa/d26f2ad40f52d4436265e6f88b1df036 to your computer and use it in GitHub Desktop.
Human-in-the-loop (HITL) optimization with Optuna
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import textwrap | |
import time | |
from typing import NoReturn | |
import optuna | |
from PIL import Image | |
from optuna.trial import TrialState | |
from optuna_dashboard import ObjectiveChoiceWidget, save_note | |
from optuna_dashboard import register_objective_form_widgets | |
from optuna_dashboard import set_objective_names | |
from optuna_dashboard.artifact import get_artifact_path, upload_artifact | |
from optuna_dashboard.artifact.file_system import FileSystemBackend | |
# Working directories for generated images. Both are resolved relative to the
# directory containing this script (not the current working directory).
artifact_path = os.path.join(os.path.dirname(__file__), "artifact")
tmp_path = os.path.join(os.path.dirname(__file__), "tmp")

# Study storage. Note that this SQLite URL (three slashes => relative path) is
# resolved against the current working directory, unlike the paths above.
url = "sqlite:///db.sqlite3"
storage = optuna.storages.RDBStorage(url=url)

# File-system artifact store that upload_artifact() copies images into.
artifact_backend = FileSystemBackend(base_path=artifact_path)
def suggest_and_generate_image(study: optuna.Study) -> None:
    """Ask the study for one new trial and publish its image for rating.

    Samples an RGB colour, renders it as a 320x240 solid-colour PNG under
    ``tmp_path``, uploads that PNG to ``artifact_backend`` and saves a
    Markdown note on the trial that embeds the uploaded image. The trial is
    left in the RUNNING state (it comes from ``study.ask()`` and is never
    told here); presumably it is completed from the dashboard once a human
    submits a score.

    Args:
        study: The study to ask a new trial from.
    """
    # Ask new parameters
    trial = study.ask()
    r = trial.suggest_int("r", 0, 255)
    g = trial.suggest_int("g", 0, 255)
    b = trial.suggest_int("b", 0, 255)

    # Generate image. Write it under ``tmp_path`` (resolved next to this
    # script, and the directory main() creates) instead of a CWD-relative
    # "tmp/" path, which breaks when the script is started from elsewhere.
    image_path = os.path.join(tmp_path, f"sample-{trial.number}.png")
    image = Image.new("RGB", (320, 240), color=(r, g, b))
    image.save(image_path)

    # Upload Artifact and resolve the path it is served from. Named
    # ``artifact_url`` so it does not shadow the module-level
    # ``artifact_path`` directory constant.
    artifact_id = upload_artifact(artifact_backend, trial, image_path)
    artifact_url = get_artifact_path(trial, artifact_id)

    # Save Note. The note embeds the uploaded image so the human rater can
    # actually see the colour being scored.
    note = textwrap.dedent(
        f"""\
    ## Trial {trial.number}

    ![generated-image]({artifact_url})
    """
    )
    save_note(trial, note)
def start_preferential_optimization() -> NoReturn:
    """Run the human-in-the-loop optimization loop forever.

    Creates (or reloads) the "Preferential Optimization" study, sets the
    objective name and a three-way choice widget for scoring, then keeps at
    most eight trials RUNNING at once: when fewer are outstanding, a new
    image trial is generated; otherwise it prints "sleep" and waits one
    second. This function never returns.
    """
    # Create Study. constant_liar lets TPE account for trials that are still
    # RUNNING, so the batch of concurrently pending trials gets varied
    # suggestions; the fixed seed makes sampling reproducible.
    study = optuna.create_study(
        study_name="Preferential Optimization",
        storage=storage,
        sampler=optuna.samplers.TPESampler(constant_liar=True, seed=42),
        load_if_exists=True,
    )

    # create_study() may wrap the RDB storage in optuna's private
    # _CachedStorage; if so, point the study at the raw backend instead.
    # Presumably this ensures trial updates made by another process (the
    # dashboard) are read from the database rather than from a cache.
    # NOTE(review): relies on private optuna internals -- re-check on upgrade.
    wrapped = study._storage
    if isinstance(wrapped, optuna.storages._cached_storage._CachedStorage):
        study._storage = wrapped._backend

    set_objective_names(study, ["Looks like sunset color?"])
    # The study uses optuna's default direction (minimize), so the best
    # answer ("Good") maps to the lowest value.
    score_widget = ObjectiveChoiceWidget(
        choices=["Good 👍", "So-so👌", "Bad 👎"],
        values=[-1, 0, 1],
        description="Please input your score!",
    )
    register_objective_form_widgets(study, widgets=[score_widget])

    # Start Preferential Optimization: keep up to max_running trials awaiting
    # a human score at any one time.
    max_running = 8
    while True:
        pending = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,))
        if len(pending) < max_running:
            suggest_and_generate_image(study)
        else:
            print("sleep")
            time.sleep(1)
def main() -> None:
    """Create the working directories, then run the optimization loop."""
    # os.makedirs(exist_ok=True) replaces the exists()/mkdir() pair: one call,
    # and no window in which another process can create the directory between
    # the check and the mkdir (which would raise FileExistsError).
    os.makedirs(artifact_path, exist_ok=True)
    os.makedirs(tmp_path, exist_ok=True)

    # Run optimize loop (start_preferential_optimization never returns).
    start_preferential_optimization()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment