Last active
May 19, 2025 03:50
-
-
Save ugo-nama-kun/ae371b359d8afa9428edd08e83d61478 to your computer and use it in GitHub Desktop.
IQM (interquartile mean) + 95% bootstrap CI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from os import makedirs | |
import wandb | |
import pandas as pd | |
import seaborn as sns | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy import stats | |
# Output directory for CSVs and plots (created with exist_ok, so reruns are fine).
path_dir = "data_save_plot"

# W&B project in "entity/project" form, e.g. "user_name/cpc_project".
project_path = "your_name/marl_cpc"

# Run groups and environments to extract. Environments are matched against the
# "env_id" entry of each run's config, e.g. ['BoxPassEnv', 'AlertEnv'].
target_groups = ["cpc"]  # e.g. ['cpc', 'shared', 'message', 'independent']
target_envs = ["MyEnv"]

# W&B metrics to summarize; IQM + 95% bootstrap CI is computed for each one.
performance_measures = [
    'test/average_reward/agent0',
    'test/average_reward/agent1',
]

# Bootstrap settings. Report the resample count alongside results; 2000 follows
# the count the original author took from rliable.
N_BOOTSTRAP = 2000
CI_ALPHA = 0.05  # two-sided alpha -> 95% confidence interval
SEED = 0  # fixed seed so the bootstrap CIs are reproducible across reruns


def iqm(x):
    """Return the interquartile mean (IQM) of ``x``, ignoring NaNs.

    The IQM is the 25%-trimmed mean: the lowest and highest 25% of the values
    are discarded and the remaining middle values are averaged. This is the
    same estimator rliable uses (``scipy.stats.trim_mean(x, 0.25)``). scipy
    trims ``int(0.25 * n)`` values from each end, so for ``n < 4`` nothing is
    trimmed and the plain mean is returned.

    Parameters
    ----------
    x : array-like or pandas.Series
        Scores; NaN entries are dropped first.

    Returns
    -------
    float
        The IQM, or NaN when no non-NaN value remains (instead of raising).
    """
    values = np.asarray(x, dtype=float)
    values = values[~np.isnan(values)]
    if values.size == 0:
        return np.nan
    return float(stats.trim_mean(values, proportiontocut=0.25))


def bootstrap_ci(data, func, n_bootstrap=2000, alpha=0.05, rng=None):
    """Percentile-bootstrap confidence interval for ``func(data)``.

    Parameters
    ----------
    data : pandas.Series or array-like
        Sample; NaN entries are dropped before resampling.
    func : callable
        Statistic to bootstrap (e.g. ``iqm``). It receives each resample as a
        ``pandas.Series`` and must return a scalar.
    n_bootstrap : int
        Number of resamples (with replacement, same size as the sample).
    alpha : float
        Two-sided significance level; 0.05 gives a 95% interval.
    rng : numpy.random.Generator, int or None
        Randomness source passed to ``np.random.default_rng``. Pass a seed or a
        Generator for reproducible intervals; None draws fresh entropy.

    Returns
    -------
    tuple
        ``(lower, upper)`` percentile bounds, or ``(nan, nan)`` when the
        sample has no non-NaN values.
    """
    values = pd.Series(data).dropna().to_numpy()
    if values.size == 0:
        return np.nan, np.nan
    rng = np.random.default_rng(rng)
    # Draw all resample indices at once: one row of indices per bootstrap replicate.
    indices = rng.integers(0, values.size, size=(n_bootstrap, values.size))
    estimates = [func(pd.Series(values[row])) for row in indices]
    lower, upper = np.percentile(estimates, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return lower, upper


def measure_slug(performance_measure):
    """Return the file-name suffix for a metric key: 'a/b/c' -> '_a_b_c'."""
    return "".join("_" + part for part in performance_measure.split("/"))


def fetch_history(runs, performance_measure):
    """Download (global_step, metric) history for each run and stack it.

    Each row is tagged with the run's ``group`` and ``run_id``; each run's ID is
    printed as it is fetched.

    NOTE(review): ``Run.history`` returns a *sampled* history (500 rows by
    default), so runs may not share identical ``global_step`` values and the
    per-step IQM below may see fewer runs at some steps. Consider
    ``Run.scan_history`` if every logged step is needed — confirm for your data.

    Returns
    -------
    pandas.DataFrame or None
        The concatenated history, or None when ``runs`` is empty
        (``pd.concat`` raises on an empty list).
    """
    frames = []
    for run in runs:
        history = run.history(keys=["global_step", performance_measure])
        print(f"ID: {run.id}")
        history['group'] = run.group
        history['run_id'] = run.id
        frames.append(history)
    if not frames:
        return None
    return pd.concat(frames, ignore_index=True)


def summarize_iqm(df_all, performance_measure, n_bootstrap=N_BOOTSTRAP,
                  alpha=CI_ALPHA, rng=None):
    """Compute IQM and a bootstrap CI of ``performance_measure`` per (group, global_step).

    Returns a DataFrame with columns group, global_step, iqm, ci_lower, ci_upper
    (empty but with those columns when ``df_all`` has no rows).
    """
    rng = np.random.default_rng(rng)
    rows = []
    for (group, step), sub_df in df_all.groupby(['group', 'global_step']):
        scores = sub_df[performance_measure]
        lower, upper = bootstrap_ci(scores, iqm, n_bootstrap=n_bootstrap,
                                    alpha=alpha, rng=rng)
        rows.append({
            'group': group,
            'global_step': step,
            'iqm': iqm(scores),
            'ci_lower': lower,
            'ci_upper': upper,
        })
    return pd.DataFrame(rows, columns=['group', 'global_step', 'iqm', 'ci_lower', 'ci_upper'])


def plot_iqm(summary_iqm, groups, target_env, performance_measure, out_path):
    """Plot one IQM curve with a shaded CI band per group and save it to ``out_path``.

    The figure is closed after saving so figures do not pile up across the
    (env, metric) loop.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for group in groups:
        sub = summary_iqm[summary_iqm['group'] == group]
        if sub.empty:
            continue  # no data for this group: skip rather than add an empty legend entry
        (line,) = ax.plot(sub['global_step'], sub['iqm'], label=group)
        # Reuse the curve's colour so each CI band visibly belongs to its line.
        ax.fill_between(sub['global_step'], sub['ci_lower'], sub['ci_upper'],
                        color=line.get_color(), alpha=0.2)
    ax.set_xlabel("Step")
    ax.set_ylabel(f"IQM of {performance_measure}")
    ax.set_title(f"Interquartile Mean over Time (env_id = {target_env})")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def main():
    """Fetch runs from W&B, cache raw histories as CSV, and plot IQM curves with bootstrap CIs."""
    makedirs(path_dir, exist_ok=True)
    rng = np.random.default_rng(SEED)
    api = wandb.Api()
    # Fetch the run list once; it is re-filtered for every environment below.
    all_runs = list(api.runs(project_path))
    for target_env in target_envs:
        filtered_runs = [
            run for run in all_runs
            if run.group in target_groups and run.config.get("env_id") == target_env
        ]
        print(f"対象ラン数 @ {target_env}: {len(filtered_runs)}")
        for performance_measure in performance_measures:
            df_all = fetch_history(filtered_runs, performance_measure)
            if df_all is None or performance_measure not in df_all.columns:
                # No matching runs, or none logged this metric: nothing to summarize.
                print(f"skipped (no data): {target_env}:{performance_measure}")
                continue
            slug = measure_slug(performance_measure)
            # Cache the raw data as CSV, since reloading it from W&B every time is slow.
            df_all.to_csv(f"{path_dir}/df_all_{target_env}_{slug}.csv", index=False)
            print(f"saved.: {target_env}:{performance_measure}")
            summary_iqm = summarize_iqm(df_all, performance_measure, rng=rng)
            plot_iqm(summary_iqm, target_groups, target_env, performance_measure,
                     f"{path_dir}/performance_{target_env}_{slug}_iqm.pdf")
    print("Done.")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment