Skip to content

Instantly share code, notes, and snippets.

@c-bata
Created July 21, 2023 05:47
Show Gist options
  • Save c-bata/ffcf52ca5c6fee5ec64e56eee451aabd to your computer and use it in GitHub Desktop.
Save c-bata/ffcf52ca5c6fee5ec64e56eee451aabd to your computer and use it in GitHub Desktop.
from __future__ import annotations
import os
import shutil
import uuid
import tempfile
from typing import NoReturn
import optuna
import streamlit as st
from optuna.trial import FrozenTrial, TrialState
from optuna_dashboard.artifact.file_system import FileSystemBackend
from optuna_dashboard._form_widget import get_form_widgets_json
from optuna_dashboard._note import get_note_from_system_attrs
from optuna_dashboard.streamlit import render_trial_note
from optuna_dashboard.streamlit import render_objective_form_widgets
# Trial images are stored under an "artifact" directory next to this script;
# the generator script builds the same path when it uploads them.
artifact_path = os.path.join(
    os.path.dirname(__file__),
    "artifact",
)
artifact_backend = FileSystemBackend(
    base_path=artifact_path,
)
def get_tmp_dir() -> str:
    """Return a scratch directory private to the current Streamlit session.

    The directory is created on the first call in a session. Its path is
    cached in ``st.session_state.tmp_dir``, so later reruns of the script in
    the same session reuse it.

    Returns:
        Path of the session's scratch directory.
    """
    if "tmp_dir" not in st.session_state:
        # mkdtemp atomically creates a uniquely named directory under
        # tempfile.gettempdir() that only the current user can access. The
        # previous approach composed a name by hand and then called
        # makedirs(exist_ok=True), which would silently reuse any directory
        # already present at that path.
        st.session_state.tmp_dir = tempfile.mkdtemp()
    return st.session_state.tmp_dir
def start_streamlit() -> None:
    """Render the evaluation page for the "Human-in-the-loop Optimization" study.

    The page does the following:

    * Loads the study from ``streamlit-db.sqlite3``.
    * Lets the user pick a trial in the sidebar.
    * Shows the trial's note and, if it has an ``artifact_id`` user attribute,
      the stored image.
    * Renders the objective form widgets while the trial is still RUNNING.
    """
    tmpdir = get_tmp_dir()
    study = optuna.load_study(
        storage="sqlite:///streamlit-db.sqlite3",
        study_name="Human-in-the-loop Optimization",
    )
    # Sidebar trial picker; the label "一覧" is Japanese for "List".
    selected_trial = st.sidebar.selectbox(
        "一覧", study.trials, format_func=lambda t: t.number
    )
    if selected_trial is None:
        # With no trials yet the selectbox has no options and yields None;
        # stop here instead of crashing on ``selected_trial.user_attrs`` below.
        st.info("No trials yet. Start the generator script first.")
        return

    render_trial_note(study, selected_trial)

    artifact_id = selected_trial.user_attrs.get("artifact_id")
    if artifact_id:
        # Copy the stored artifact into the session scratch dir so st.image
        # can load it from a local file path.
        tmp_img_path = os.path.join(tmpdir, artifact_id + ".png")
        with artifact_backend.open(artifact_id) as fsrc, open(tmp_img_path, "wb") as fdst:
            shutil.copyfileobj(fsrc, fdst)
        st.image(tmp_img_path, caption="Image")

    # The scoring form is shown only while the trial is still RUNNING.
    if selected_trial.state == TrialState.RUNNING:
        render_objective_form_widgets(study, selected_trial)
if __name__ == "__main__":
    # Script entry point (this is also the case when launched via `streamlit run`).
    start_streamlit()
import os
import time
import tempfile
from typing import NoReturn
import optuna
from optuna.trial import TrialState
from optuna_dashboard import ChoiceWidget
from optuna_dashboard import register_objective_form_widgets
from optuna_dashboard import save_note
from optuna_dashboard.artifact import get_artifact_path
from optuna_dashboard.artifact import upload_artifact
from optuna_dashboard.artifact.file_system import FileSystemBackend
from optuna_dashboard.streamlit import render_trial_note
from optuna_dashboard.streamlit import render_objective_form_widgets
from PIL import Image
def suggest_and_generate_image(study: optuna.Study, artifact_backend: FileSystemBackend, tmpdir: str) -> None:
    """Ask the study for a new trial and attach a solid-color image to it.

    The trial proceeds as follows:

    * It suggests integer ``r``, ``g`` and ``b`` values in [0, 255].
    * A 320x240 image of that color is rendered and uploaded to
      ``artifact_backend``.
    * The resulting artifact ID is stored in the trial's ``artifact_id`` user
      attribute.
    * A Markdown note titled with the trial number is saved.

    The trial itself is left RUNNING so it can be scored later.

    If any step after ``study.ask()`` raises, the trial is marked FAIL before
    the exception propagates. Otherwise a half-prepared trial would stay
    RUNNING in the persistent storage.

    Args:
        study: Study to ask the new trial from.
        artifact_backend: Backend that stores the generated image.
        tmpdir: Existing directory used for the intermediate PNG file.
    """
    trial = study.ask()
    image_path = os.path.join(tmpdir, f"sample-{trial.number}.png")
    try:
        # 1. Ask new parameters
        r = trial.suggest_int("r", 0, 255)
        g = trial.suggest_int("g", 0, 255)
        b = trial.suggest_int("b", 0, 255)
        # 2. Generate image
        image = Image.new("RGB", (320, 240), color=(r, g, b))
        image.save(image_path)
        # 3. Upload Artifact
        artifact_id = upload_artifact(artifact_backend, trial, image_path)
        trial.set_user_attr("artifact_id", artifact_id)
        # 4. Save Note
        save_note(trial, f"## Trial {trial.number}")
    except BaseException:
        # BaseException so that Ctrl-C mid-generation is covered too. A stale
        # RUNNING trial would count toward the caller's batch limit and show
        # up for scoring without an image. The error is always re-raised.
        study.tell(trial, state=TrialState.FAIL)
        raise
    finally:
        # The uploaded artifact lives in the backend; drop the local copy so
        # the temp dir does not grow for the whole endless generator run.
        if os.path.exists(image_path):
            os.remove(image_path)
def main() -> NoReturn:
    """Run the generator side of the human-in-the-loop optimization.

    The function does the following:

    * Creates the artifact directory if it is missing.
    * Creates or reuses the SQLite-backed study.
    * Registers the choice widget used to score trials.
    * Keeps at most ``n_batch`` trials in the RUNNING state, generating a new
      candidate image whenever fewer are running.

    The loop never returns; interrupt the process to stop it.
    """
    # 1. Create Artifact Store
    artifact_path = os.path.join(os.path.dirname(__file__), "artifact")
    # makedirs(exist_ok=True) instead of exists() + mkdir(): no race if the
    # directory is created by someone else in between.
    os.makedirs(artifact_path, exist_ok=True)
    artifact_backend = FileSystemBackend(base_path=artifact_path)

    # 2. Create Study
    study = optuna.create_study(
        study_name="Human-in-the-loop Optimization",
        storage="sqlite:///streamlit-db.sqlite3",
        sampler=optuna.samplers.TPESampler(constant_liar=True, n_startup_trials=5),
        load_if_exists=True,
    )
    study.set_metric_names(["Looks like sunset color?"])

    # 3. Register ChoiceWidget
    # No direction is passed to create_study, so the study minimizes (Optuna's
    # default); that is why "Good" maps to the lowest value.
    register_objective_form_widgets(
        study,
        widgets=[
            ChoiceWidget(
                choices=["Good 👍", "So-so👌", "Bad 👎"],
                values=[-1, 0, 1],
                description="Please input your score!",
            ),
        ],
    )

    # 4. Start Human-in-the-loop Optimization
    n_batch = 4  # upper bound on trials left RUNNING at the same time
    with tempfile.TemporaryDirectory() as tmpdir:
        while True:
            running_trials = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,))
            if len(running_trials) >= n_batch:
                time.sleep(1)  # Avoid busy-loop
                continue
            suggest_and_generate_image(study, artifact_backend, tmpdir)
if __name__ == "__main__":
    # Script entry point: start the endless generator loop.
    main()
@c-bata
Copy link
Author

c-bata commented Jul 21, 2023

$ python generateor.py
$ streamlit run evaluator.py

Screenshot 2023-07-21 14 46 36

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment