Created
May 30, 2025 12:04
-
-
Save cutecutecat/353d29aefebaad17c648805e48dcdcd0 to your computer and use it in GitHub Desktop.
Script that draws the filter QPS-vs-precision plot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from matplotlib import ticker | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
# Benchmark measurements: element i of "acc" pairs with element i of QPS.
# The first four points are the baseline runs, the last four were measured
# with the prefilter enabled.
QPS = [13.42, 12.32, 11.23, 8.75, 27.65, 25.36, 20.41, 16.66]
original_data = {
    "acc": [
        0.9265,
        0.9323,
        0.9533,
        0.9653,
        0.9265,
        0.9323,
        0.9534,
        0.9652,
    ],
    "QPS": QPS,
    # The hue column is named "" so seaborn's legend is drawn without a title.
    "": ["baseline"] * 4 + ["prefilter enabled"] * 4,
}
df = pd.DataFrame(original_data)
categories = df[""].unique()


def interpolate_curve(category_data, num_points=100):
    """Return a densely sampled, piecewise-linear QPS-vs-accuracy curve.

    Args:
        category_data: DataFrame with numeric "acc" and "QPS" columns holding
            the measured points of a single configuration.
        num_points: Number of evenly spaced accuracy samples between the
            smallest and largest measured accuracy (inclusive).

    Returns:
        A new DataFrame with columns "acc" and "QPS" and ``num_points`` rows.
    """
    # np.interp assumes increasing x-coordinates; sort so the curve does not
    # silently depend on the order in which the points were recorded.
    ordered = category_data.sort_values("acc")
    acc_values = np.linspace(ordered["acc"].min(), ordered["acc"].max(), num_points)
    qps_values = np.interp(acc_values, ordered["acc"], ordered["QPS"])
    return pd.DataFrame({"acc": acc_values, "QPS": qps_values})


# Build every per-category curve first and concatenate once, instead of
# repeatedly growing a DataFrame that starts out empty.
curve_frames = []
for category in categories:
    curve = interpolate_curve(df[df[""] == category])
    curve[""] = category
    curve_frames.append(curve)
df_curve = pd.concat(curve_frames, ignore_index=True)
# Plot style and fonts (set globally through rcParams).
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 8

# One color per configuration, keyed by label so the curves and the markers
# always agree, independent of the order in which categories appear in data.
PALETTE = {"baseline": "#BF7E04", "prefilter enabled": "#2ca02c"}

# Create the figure.
fig, ax = plt.subplots(figsize=(5, 3))
# Same background color for the plotting area and the surrounding figure.
ax.set_facecolor("#f0f7f4")
fig.set_facecolor("#f0f7f4")

# Interpolated curves (piecewise-linear between the measured points, not a
# smoothed fit). The legend is taken from the scatter layer below instead.
sns.lineplot(
    x="acc",
    y="QPS",
    hue="",
    data=df_curve,
    dashes=False,
    linewidth=1.5,
    ax=ax,
    palette=PALETTE,
    legend=False,
)
# Measured data points. `style` maps the same column as `hue` with identical
# circle markers, so groups are distinguished by color only.
sns.scatterplot(
    x="acc",
    y="QPS",
    hue="",
    data=df,
    style="",
    markers=["o", "o"],
    s=50,
    linewidth=1.5,
    ax=ax,
    palette=PALETTE,
    legend=True,
)

# Title and axis labels.
ax.set_title(
    "LAION-5m Query Per Second (Filter rate 0.01)",
    fontsize=12,
    fontweight="bold",
    pad=20,
)
ax.set_xlabel("Precision@Top 100", fontsize=10)
ax.set_ylabel("QPS", fontsize=10)
# Dashed, semi-transparent grid lines.
ax.grid(True, linestyle="--", alpha=0.6)

# Save the figure, then release it so repeated runs in one process don't leak.
fig.savefig("filter.png", dpi=300, bbox_inches="tight")
plt.close(fig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment