Created
December 5, 2022 12:48
-
-
Save chenyaofo/6b3be5b88b7d766c52bac0a69011285b to your computer and use it in GitHub Desktop.
Scientific figure codebase.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
from matplotlib import font_manager
from matplotlib import ticker
# more details about matplotlib usage at | |
# https://wizardforcel.gitbooks.io/matplotlib-user-guide/content/ | |
def color(*items):
    """Normalise 0-255 colour channel values to the 0-1 floats matplotlib expects.

    Accepts any number of channels (e.g. R, G, B or R, G, B, A) and returns
    them as a list in the same order, each divided by 255.
    """
    return list(map(lambda channel: channel / 255, items))
# All colours below come from the palette collection at
# https://sci-figure-colors.unvs.cc/%E7%A7%91%E7%A0%94%E8%AE%BA%E6%96%87%E6%8F%92%E5%9B%BE%E9%85%8D%E8%89%B2.pdf
# Each palette is keyed by its number in that collection and maps a descriptive
# colour name to an RGB list normalised to 0-1 by ``color``.
COLORS = {
    # Palette No.002
    2: {
        "red": color(223, 122, 94),
        "blue": color(60, 64, 91),
        "green": color(130, 178, 194),
        "yellow": color(240, 210, 134),
    },
    # Palette No.006
    6: {
        "red": color(230, 111, 81),
        "blue": color(38, 70, 83),
        "light_green": color(138, 176, 125),
        "green": color(41, 157, 143),
        "dark_green": color(40, 114, 113),
        "yellow": color(232, 197, 107),
        "orange": color(243, 162, 97),
    },
    # Palette No.007
    7: {
        "dark_red": color(168, 3, 38),
        "red": color(236, 93, 59),
        "orange": color(253, 185, 107),
        "dark_blue": color(57, 81, 162),
        "blue": color(114, 170, 207),
        "light_blue": color(202, 232, 242),
    },
}
def show_available_font():
    """Print the name and file path of every font registered with matplotlib.

    Handy for looking up the exact family name of an installed font, e.g.
    before passing it to ``set_global_matplotlib_font``.
    """
    # more details at https://cloud.tencent.com/developer/article/1761532
    print("All available font in matplotlib are as follows:")
    registered = font_manager.fontManager.ttflist
    for entry in registered:
        print(f"{entry.name} - {entry.fname}")
def set_global_matplotlib_font(fontfamily: str):
    """Make *fontfamily* the font matplotlib uses for sans-serif text.

    This replaces the ``font.sans-serif`` rcParams entry for the whole
    process. It only changes rendered text while ``font.family`` is
    ``'sans-serif'`` (matplotlib's default); this function does not touch
    ``font.family``.
    """
    plt.rcParams.update({'font.sans-serif': fontfamily})
def add_temp_font_in_matplotlib(fontpath):
    """Register the font file at *fontpath* with matplotlib's global font manager.

    After this call the font can be selected by its family name in the
    current session (e.g. via ``set_global_matplotlib_font``).
    """
    manager = font_manager.fontManager
    manager.addfont(fontpath)
def to_percent(value, position):
    """Render a fractional tick value as a whole-number percentage.

    Intended for ``matplotlib.ticker.FuncFormatter``, which calls it as
    ``f(value, position)``; *position* (the tick index) is not used.
    Example: ``to_percent(0.25, 0) == "25%"``.
    """
    return "{:.0f}%".format(value * 100)
def _save_figure(fig, save_path: str):
    """Write *fig* to *save_path*, picking the output format from its extension.

    ``.pdf`` and ``.svg`` paths (case-insensitive) are saved with an explicit
    format and ``dpi=600``. Every other path is passed to ``Figure.savefig``
    as-is, so matplotlib infers the format and uses its default dpi.
    """
    lowered = str(save_path).lower()
    for extension in ("pdf", "svg"):
        # Match the real ".pdf"/".svg" extension rather than a bare "pdf"
        # suffix, so e.g. "report_pdf" is not silently written as a PDF.
        if lowered.endswith("." + extension):
            fig.savefig(save_path, format=extension, dpi=600)
            return
    fig.savefig(save_path)


def plot_line(
    Xs: List[List[List[float]]],
    Ys: List[List[List[float]]],
    labels: Optional[List[str]],
    x_labels: List[str],
    y_labels: List[str],
    colors: List[List[List[float]]],
    title: str,
    subtitles: List[str],
    n_rows: int = 1,
    xlims: Optional[List[Tuple[float, float]]] = None,
    ylims: Optional[List[Tuple[float, float]]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    grid: bool = False,
    top_right_visible: bool = True,
    linewidth: int = 3,
    fontsize: int = 20,
    yaxis_formatters=None,
    save_path: str = "figure.png",
):
    """Draw a grid of line-plot subplots and save the figure to *save_path*.

    Subplot ``i`` is placed on an ``n_rows`` x ``len(Xs) // n_rows`` grid
    (filled row by row). It draws one line per entry of
    ``zip(Xs[i], Ys[i], labels, colors[i])``.

    Args:
        Xs: Per subplot, a list of x-value sequences (one per line).
        Ys: Per subplot, a list of y-value sequences paired with ``Xs``.
        labels: Legend label for each line, shared by every subplot. Pass
            ``None`` for unlabelled lines; no legend is drawn in that case.
        x_labels: X-axis label for each subplot.
        y_labels: Y-axis label for each subplot.
        colors: Per subplot, one matplotlib colour per line.
        title: Figure-level title, drawn at ``fontsize + 10``.
        subtitles: Title for each subplot.
        n_rows: Number of subplot rows; must divide ``len(Xs)``.
        xlims: Per subplot ``(low, high)`` x limits. Pass ``None`` (for the
            whole list or for single entries) to leave the limits untouched.
        ylims: Same as ``xlims`` for the y axis.
        figsize: Figure size in inches; defaults to 9 x 5 per subplot.
        grid: Draw a dashed grey grid on every subplot.
        top_right_visible: When ``False``, hide the top and right spines.
        linewidth: Width of every plotted line.
        fontsize: Font size for ticks, axis labels, subplot titles and legends.
        yaxis_formatters: Per subplot ``f(value, position) -> str`` used for
            major y tick labels (e.g. ``to_percent``). Pass ``None`` to keep
            matplotlib's default labels.
        save_path: Output file. ``.pdf``/``.svg`` are written with
            ``dpi=600``; any other path uses matplotlib's defaults.

    Raises:
        AssertionError: If ``Xs``/``Ys``/``colors``/``labels`` lengths
            disagree, or if ``n_rows`` does not divide ``len(Xs)``.

    Note:
        The per-subplot lists are consumed with ``zip``, so a list shorter
        than ``Xs`` silently drops the trailing subplots. The figure is not
        closed after saving.
    """
    assert len(Xs) == len(Ys)
    assert len(Xs) == len(colors)
    assert len(Xs[0]) == len(colors[0])
    if labels is not None:
        assert len(Xs[0]) == len(labels)
    else:
        labels = [None] * len(Xs[0])
    if xlims is None:
        xlims = [None] * len(Xs)
    if ylims is None:
        ylims = [None] * len(Xs)
    if yaxis_formatters is None:
        yaxis_formatters = [None] * len(Xs)
    assert len(Xs) % n_rows == 0
    n_columns = len(Xs) // n_rows
    if figsize is None:
        figsize = (n_columns * 9, n_rows * 5)
    # Only build a legend when some line is labelled. Otherwise matplotlib
    # warns "No artists with labels found to put in legend".
    show_legend = any(label is not None for label in labels)

    fig = plt.figure(figsize=figsize)
    panels = zip(Xs, Ys, x_labels, y_labels, xlims, ylims,
                 subtitles, yaxis_formatters, colors)
    for i, (_X, _Y, x_label, y_label, xlim, ylim,
            subtitle, yaxis_formatter, subcolors) in enumerate(panels, start=1):
        ax = fig.add_subplot(n_rows, n_columns, i)
        # `line_color` (not `color`) avoids shadowing the module-level color().
        for x, y, label, line_color in zip(_X, _Y, labels, subcolors):
            ax.plot(x, y, color=line_color, linestyle='-',
                    linewidth=linewidth, label=label)
        ax.tick_params(axis='both', which='both', labelsize=fontsize)
        ax.set_xlabel(x_label, fontsize=fontsize)
        ax.set_ylabel(y_label, fontsize=fontsize)
        ax.set_title(subtitle, fontsize=fontsize)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        if yaxis_formatter:
            ax.yaxis.set_major_formatter(ticker.FuncFormatter(yaxis_formatter))
        if grid:
            ax.grid(color='gray', linestyle='--', linewidth=0.5)
        if not top_right_visible:
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
        if show_legend:
            ax.legend(fontsize=fontsize)
    fig.suptitle(title, fontsize=fontsize + 10)
    fig.tight_layout()
    _save_figure(fig, save_path)
# NOTE(review): torch.load can unpickle arbitrary objects (depending on the
# installed torch version's `weights_only` default); only load trusted files.
baseline = torch.load("baseline_metrics.pt")
oursv1 = torch.load("our_metrics.pt")
oursv2 = torch.load("oursv2_metrics.pt")

# Derive the epoch axis from the data instead of hard-coding 90 epochs, so the
# script also works for runs of a different length.
# Assumes every series holds exactly one value per epoch — TODO confirm for eval/*.
x = list(range(1, len(baseline["train/loss"]) + 1))

runs = [baseline, oursv1, oursv2]
metric_keys = ["train/loss", "eval/loss", "train/top1_acc", "eval/top1_acc"]

show_available_font()
set_global_matplotlib_font("Times New Roman")
plot_line(
    Xs=[[x] * 3] * 4,
    # One subplot per metric key, one line per run (baseline, Ours-A, Ours-B).
    Ys=[[run[key] for run in runs] for key in metric_keys],
    labels=["baseline", "Ours-A", "Ours-B"],
    x_labels=["Epoch"] * 4,
    y_labels=["Loss", "Loss", 'Acc. (%)', 'Acc. (%)'],
    xlims=[[0, len(x)]] * 4,
    ylims=[[0, 6], [0, 6], [0, 1], [0, 1]],
    colors=[[COLORS[6]["dark_green"], COLORS[7]["dark_blue"], COLORS[7]["dark_red"]]] * 4,
    title="Evolution of Loss/Accuracy",
    subtitles=["Training Loss", "Eval Loss", "Training Top-1 Acc.", "Eval Top-1 Acc."],
    n_rows=2,
    grid=True,
    top_right_visible=False,
    yaxis_formatters=[None, None, to_percent, to_percent],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment