Last active
October 28, 2023 01:04
-
-
Save c-bata/0eed0dfb416a6994fa30fb23bb38d3ad to your computer and use it in GitHub Desktop.
W&B 東京Meetup #3 - Optunaを使ったHuman-in-the-loop最適化の紹介
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Requirements: | |
# $ pip install optuna optuna-dashboard[preferential] diffusers transformers accelerate scipy safetensors xformers | |
# | |
# Process A: Launch a process that suggest new params and generate images. | |
# $ python main.py | |
# | |
# Process B: Launch an Optuna Dashboard process. | |
# $ optuna-dashboard sqlite:///db.sqlite3 --artifact-dir ./artifact | |
import time | |
import optuna | |
from optuna.trial import TrialState | |
import os | |
from optuna_dashboard import set_objective_names, register_objective_form_widgets, ChoiceWidget, save_note | |
from optuna_dashboard import save_note | |
from optuna_dashboard.artifact import upload_artifact, get_artifact_path | |
from optuna_dashboard.artifact.file_system import FileSystemBackend | |
from diffusers import StableDiffusionImg2ImgPipeline | |
import torch | |
from PIL import Image | |
device = "cuda:0" | |
torch_dtype = torch.float16 | |
init_img = Image.open("./input.png") | |
init_img = init_img.resize((768, 768)) | |
init_img = init_img.convert("RGB") | |
# 画像(Artifact)のアップロード先の設定 - 今回はFileSystemBackendを用いて、 "artifact" ディレクトリ以下に保存。 | |
base_path = os.path.join(os.path.dirname(__file__), "artifact") | |
artifact_backend = FileSystemBackend(base_path=base_path) | |
def suggest_and_generate_image(study: optuna.Study, pipe: StableDiffusionImg2ImgPipeline): | |
# OptunaのTrialを生成し、パラメーターおよびプロンプトに含めるキーワードをサンプル | |
trial = study.ask() | |
guidance_scale = trial.suggest_float("guidance_scale", 1, 50) | |
strength = trial.suggest_float("strength", 0.70, 1.0) | |
num_inference_steps = trial.suggest_int("num_inference_steps", 5, 100) | |
prompts = ["a mascot character with two eyes and a mouth"] | |
prompts.append(trial.suggest_categorical("adjectives", ["cute", "funny", "memorable", "charming", "entertaining"])) | |
prompts.append(trial.suggest_categorical("style", ["anime", "photo", "painting", ""])) | |
prompts.append(trial.suggest_categorical("facial-expression", ["smiling", "frowning", "grinning", ""])) | |
negative_prompt = [] | |
negative_prompt.append(trial.suggest_categorical("negative-quality", ["unnatural", "low-quality", ""])) | |
negative_prompt.append(trial.suggest_categorical("negative-adjectives", ["dull", "boring", "unfriendly", ""])) | |
# img2imgの実行 | |
images = pipe( | |
", ".join(prompts), | |
negative_prompt=", ".join(negative_prompt), | |
generator=torch.Generator(device).manual_seed(0), | |
strength=strength, | |
image=init_img, | |
guidance_scale=guidance_scale, | |
num_inference_steps=num_inference_steps, | |
num_images_per_prompt=3, | |
).images | |
# 画像アップロードとURLパスの取得 | |
url_paths = [] | |
for i, image in enumerate(images): | |
image_path = f"tmp/sample-{trial.number}-{i}.png" | |
image.save(image_path) | |
artifact_id = upload_artifact(artifact_backend, trial, image_path) | |
url_path = get_artifact_path(trial, artifact_id) | |
url_paths.append(url_path) | |
# ノート (Markdown Format) の保存 | |
# Optuna DashboardのTrial Listタブから、各Trialを選択した際に表示されます。 | |
note = f"""## Prompt | |
* Prompt: {", ".join(prompts)} | |
* Negative Prompt: {negative_prompt} | |
## Images | |
| Seed 0 | Seed 1 | Seed 2 | | |
| ------ | ------ | ------ | | |
| ![image0]({url_paths[0]}) | ![image1]({url_paths[1]}) | ![image2]({url_paths[2]}) | | |
""" | |
save_note(trial, note) | |
def main(): | |
study = optuna.create_study( | |
storage="sqlite:///db.sqlite3", | |
sampler=optuna.samplers.TPESampler(n_startup_trials=5, constant_liar=True, multivariate=True), | |
direction="minimize", | |
study_name="Generate Optuna-kun (img2img)", | |
load_if_exists=True | |
) | |
# Optuna v3.1.1以前のバージョンでは、Optunaの _CachedStorage と呼ばれるキャッシュのためのコンポーネントの問題により、 | |
# 本チュートリアルのコードが正しく動きません。v3.2で修正予定ですが、v3.1.1以前のバージョンをご使用中の方は次のコードを実行してください。 | |
if isinstance(study._storage, optuna.storages._CachedStorage): | |
study._storage = study._storage._backend | |
set_objective_names(study, ["Score"]) | |
register_objective_form_widgets(study, widgets=[ | |
ChoiceWidget( | |
choices=["Good 👍", "So-so👌", "Bad 👎"], | |
values=[-1, 0, 1], | |
description="Please input your score!", | |
), | |
]) | |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( | |
"stabilityai/stable-diffusion-2-1", | |
torch_dtype=torch_dtype, | |
) | |
pipe = pipe.to(device) | |
pipe.enable_xformers_memory_efficient_attention() | |
# Start Human-in-the-loop Optimization | |
n_batch = 8 | |
while True: | |
# 実行中のTrialが n_batch (8個) を下回ったら新しくパラメーターをサンプルして、画像を生成 | |
running_trials = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,)) | |
if len(running_trials) >= n_batch: | |
time.sleep(2) | |
continue | |
suggest_and_generate_image(study, pipe) | |
if __name__ == '__main__': | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
動作の様子
実行方法
推奨実行環境
本Exampleを動かす際は Linux / NVIDIA GPUの利用を推奨 しています。もしGPUがなく動かすのが大変という方は、Optuna Dashboard公式ドキュメントの下記チュートリアルをお試しください。こちらはCPUのみで簡単に動かすことができ、Human-in-the-loop最適化のためのコードの理解には十分です。
https://optuna-dashboard.readthedocs.io/en/latest/tutorials/hitl.html
依存関係のインストール
入力画像の用意
下記画像はOptunaのロゴに手足のようなものを足したものです。
input.png
というファイル名でmain.py
実行時のlocationと同じディレクトリ内に保存してください。 (※ 本画像の利用はデモの動作確認のみにとどめてください)実行
次の2つのコマンドをそれぞれ実行してください。
python main.py
optuna-dashboard sqlite:///db.sqlite3 --artifact-dir ./artifact
macOSをお使いの方
xformersのインストール
本プログラムの実行に必要な
xformers
はmacOS向けのwheelバイナリを配布していないため、macOSではsdist (ソース配布パッケージ) からビルドされます。その際、Appleが提供しているClangにおいてOpenMPがデフォルトでは利用出来ないことから、次のようなエラーにぶつかることがあります。この場合は
libomp
を別途インストールしたり、別のCコンパイラに切り替えることでインストールが可能です。例えば後者は次のようにできます。main.py
の変更macOSで動かしたい方は、
main.py
に次のpatchを適用してください。pipe = pipe.to(device) -pipe.enable_xformers_memory_efficient_attention()