# IQM + 95% bootstrap CI
# Source: GitHub Gist ae371b359d8afa9428edd08e83d61478 by @ugo-nama-kun (last active May 19, 2025).
from os import makedirs
import wandb
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
# Output directory for the cached CSVs and the plots (created if missing).
path_dir = "data_save_plot"
# (1) Connect to the wandb public API.
api = wandb.Api()
# (2) wandb project to read runs from, in "user_name/project_name" form.
project_path = "your_name/marl_cpc"  # e.g. "user_name/cpc_project"
# (3) Run groups and environments to include.
target_groups = ["cpc"]  # e.g. ['cpc', 'shared', 'message', 'independent']
# Set to the environment names of interest, e.g. ['BoxPassEnv', 'AlertEnv'].
# Runs are matched on their "env_id" config parameter.
target_envs = ["MyEnv"]
# wandb metrics to analyze; IQM + 95% bootstrap CI is computed for each one separately.
performance_measures = [
    'test/average_reward/agent0',
    'test/average_reward/agent1',
]
def iqm(x):
    """Return the interquartile mean (IQM) of ``x``, ignoring NaNs.

    The IQM averages the middle 50% of the scores: the lowest and highest
    25% are discarded. It is computed with
    ``scipy.stats.trim_mean(values, 0.25)``, the same definition rliable
    uses, so results are comparable with that library.

    Args:
        x: pandas Series (or any 1-D array-like) of scores.

    Returns:
        The IQM as a float, or NaN when ``x`` has no non-NaN values.
    """
    values = pd.Series(x, dtype=float).dropna().to_numpy()
    if values.size == 0:
        # np.percentile / trim_mean cannot handle an empty array.
        return np.nan
    return stats.trim_mean(values, 0.25)


# The number of bootstrap resamples is also reported; rliable used 2000, so the same is used here.
def bootstrap_ci(data, func, n_bootstrap=2000, alpha=0.05, rng=None):
    """Percentile-bootstrap confidence interval for an arbitrary statistic (e.g. IQM).

    Args:
        data: pandas Series of scores; NaNs are dropped before resampling.
        func: statistic to bootstrap; called with a pandas Series holding
            one resample (with replacement, same size as ``data``).
        n_bootstrap: number of bootstrap resamples.
        alpha: 1 - confidence level (0.05 gives a 95% interval).
        rng: optional ``numpy.random.Generator`` for reproducible intervals;
            when None the global ``np.random`` state is used.

    Returns:
        ``(lower, upper)`` percentile bounds, or ``(nan, nan)`` when
        ``data`` has no non-NaN values.
    """
    values = pd.Series(data).dropna().to_numpy()
    if values.size == 0:
        # np.random.choice raises on an empty population.
        return np.nan, np.nan
    sampler = np.random if rng is None else rng
    estimates = [
        func(pd.Series(sampler.choice(values, size=values.size, replace=True)))
        for _ in range(n_bootstrap)
    ]
    lower, upper = np.percentile(estimates, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return lower, upper


# (4) Fetch all runs of the project once; they are filtered per environment below.
all_runs = api.runs(project_path)
# Seeded generator so the bootstrap CIs (and therefore the saved plots) are reproducible.
bootstrap_rng = np.random.default_rng(0)
makedirs(path_dir, exist_ok=True)

for target_env in target_envs:
    filtered_runs = [
        run for run in all_runs
        if run.group in target_groups and run.config.get("env_id") == target_env
    ]
    print(f"対象ラン数 @ {target_env}: {len(filtered_runs)}")
    if not filtered_runs:
        # pd.concat([]) would raise "No objects to concatenate".
        print(f"skipped: no runs for {target_env}")
        continue

    for performance_measure in performance_measures:
        # (5) Collect every run's history for this metric into one long DataFrame.
        dfs = []
        for run in filtered_runs:
            # NOTE(review): wandb documents Run.history() as returning *sampled*
            # rows; if every logged step matters, consider run.scan_history()
            # (verify against your wandb version).
            history = run.history(keys=["global_step", performance_measure])  # step or time
            print(f"ID: {run.id}")
            history['group'] = run.group
            history['run_id'] = run.id
            dfs.append(history)
        df_all = pd.concat(dfs, ignore_index=True)

        # Filename tag, e.g. "test/average_reward/agent0" -> "_test_average_reward_agent0".
        metric_tag = "_" + "_".join(performance_measure.split('/'))
        # Cache the raw data as CSV, because reloading it from wandb is slow.
        df_all.to_csv(f"{path_dir}/df_all_{target_env}_{metric_tag}.csv", index=False)
        print(f"saved.: {target_env}:{performance_measure}")

        if not {"global_step", performance_measure}.issubset(df_all.columns):
            # Neither run logged this metric (or the step key); nothing to aggregate.
            print(f"skipped: {performance_measure} not logged for {target_env}")
            continue

        # (6) IQM + 95% percentile-bootstrap CI for every (group, global_step) cell.
        # NOTE(review): each cell pools only the runs that logged exactly this
        # global_step; if runs log at different steps a cell may contain a
        # single run and its CI collapses to a point — consider binning steps.
        iqm_rows = []
        for (group, step), sub_df in df_all.groupby(['group', 'global_step']):
            scores = sub_df[performance_measure]
            lower, upper = bootstrap_ci(scores, iqm, rng=bootstrap_rng)
            iqm_rows.append({
                'group': group,
                'global_step': step,
                'iqm': iqm(scores),
                'ci_lower': lower,
                'ci_upper': upper,
            })
        # Explicit columns so the 'group' filter below works even with no rows.
        summary_iqm = pd.DataFrame(
            iqm_rows, columns=['group', 'global_step', 'iqm', 'ci_lower', 'ci_upper']
        )

        # (7) Plot one IQM curve per group with its shaded CI band.
        fig = plt.figure(figsize=(10, 6))
        for group in target_groups:
            sub = summary_iqm[summary_iqm['group'] == group]
            plt.plot(sub['global_step'], sub['iqm'], label=group)
            plt.fill_between(sub['global_step'], sub['ci_lower'], sub['ci_upper'], alpha=0.2)
        plt.xlabel("Step")
        plt.ylabel(f"IQM of {performance_measure}")
        plt.title(f"Interquartile Mean over Time (env_id = {target_env})")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{path_dir}/performance_{target_env}_{metric_tag}_iqm.pdf")
        # Release the figure; otherwise one figure per (env, metric) stays open.
        plt.close(fig)

print("Done.")