# Scientific Figure Codebases.
# Source: GitHub Gist 6b3be5b88b7d766c52bac0a69011285b by @chenyaofo,
# created December 5, 2022 12:48.
import torch
from typing import List, Tuple
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib import font_manager
# more details about matplotlib usage at
# https://wizardforcel.gitbooks.io/matplotlib-user-guide/content/
def color(*items):
    """Scale 0-255 channel values down by 255.

    Returns a new list with one float per positional argument, in the same
    order, e.g. ``color(255, 0)`` -> ``[1.0, 0.0]``. The palettes below use
    it to turn RGB byte triples into matplotlib colour sequences.
    """
    scale = 255
    return list(map(lambda channel: channel / scale, items))
# Every palette below is taken from
# https://sci-figure-colors.unvs.cc/%E7%A7%91%E7%A0%94%E8%AE%BA%E6%96%87%E6%8F%92%E5%9B%BE%E9%85%8D%E8%89%B2.pdf
# Outer keys are the palette numbers ("No.xxx") from that document; each
# palette maps a colour name to an RGB list produced by ``color`` (i.e. the
# 0-255 values divided by 255).
COLORS = {
    # No.002
    2: {
        "red": color(223, 122, 94),
        "blue": color(60, 64, 91),
        "green": color(130, 178, 194),
        "yellow": color(240, 210, 134),
    },
    # No.006
    6: {
        "red": color(230, 111, 81),
        "blue": color(38, 70, 83),
        "light_green": color(138, 176, 125),
        "green": color(41, 157, 143),
        "dark_green": color(40, 114, 113),
        "yellow": color(232, 197, 107),
        "orange": color(243, 162, 97),
    },
    # No.007
    7: {
        "dark_red": color(168, 3, 38),
        "red": color(236, 93, 59),
        "orange": color(253, 185, 107),
        "dark_blue": color(57, 81, 162),
        "blue": color(114, 170, 207),
        "light_blue": color(202, 232, 242),
    },
}
def show_available_font():
    """Print a header line, then one ``<font name> - <font file>`` line per
    font in matplotlib's global font manager (``fontManager.ttflist``)."""
    # more details at https://cloud.tencent.com/developer/article/1761532
    print("All available font in matplotlib are as follows:")
    entries = ((entry.name, entry.fname) for entry in font_manager.fontManager.ttflist)
    for family, path in entries:
        print(family, "-", path)
def set_global_matplotlib_font(fontfamily: str):
    """Replace matplotlib's global ``font.sans-serif`` rc setting with ``fontfamily``.

    NOTE(review): only the sans-serif family list is overridden; this
    presumably takes effect because ``font.family`` is left at its
    sans-serif default — confirm if that rc value is changed elsewhere.
    """
    rc_key = "font.sans-serif"
    plt.rcParams[rc_key] = fontfamily
def add_temp_font_in_matplotlib(fontpath):
    """Register the font file at ``fontpath`` with matplotlib's global font manager.

    The "temp" in the name presumably means the registration is in-memory
    for the current process only — TODO confirm against matplotlib's
    ``FontManager.addfont`` behaviour.
    """
    manager = font_manager.fontManager
    manager.addfont(fontpath)
def to_percent(value, position):
    """Tick-label formatter: render a fraction as a whole-number percentage.

    ``to_percent(0.25, pos)`` -> ``"25%"``. ``position`` is the tick index
    argument that ``ticker.FuncFormatter`` passes; it is not used.
    """
    del position  # part of the FuncFormatter callback signature only
    percentage = value * 100
    return "{:.0f}%".format(percentage)
def _save_figure(fig, save_path):
    """Write ``fig`` to ``save_path``; ``*pdf`` / ``*svg`` targets are saved at 600 dpi."""
    # str() lets pathlib.Path targets through the suffix test; anything else
    # (e.g. a file-like object) falls through to a plain savefig call.
    target = str(save_path)
    for vector_format in ("pdf", "svg"):
        if target.endswith(vector_format):
            fig.savefig(save_path, format=vector_format, dpi=600)
            return
    fig.savefig(save_path)


def plot_line(
    Xs: List[List[List[float]]],
    Ys: List[List[List[float]]],
    labels: List[str],
    x_labels: List[str],
    y_labels: List[str],
    colors: List[List[Tuple[float, ...]]],
    title: str,
    subtitles: List[str],
    n_rows: int = 1,
    xlims: List[Tuple[float, float]] = None,
    ylims: List[Tuple[float, float]] = None,
    figsize: List[int] = None,
    grid: bool = False,
    top_right_visible: bool = True,
    linewidth: int = 3,
    fontsize: int = 20,
    yaxis_formatters=None,
    save_path: str = "figure.png",
    close: bool = False,
):
    """Draw a grid of line-plot subplots and save the figure to ``save_path``.

    ``Xs`` and ``Ys`` are indexed ``[subplot][curve]``: subplot ``i`` draws
    one solid line per pair ``(Xs[i][j], Ys[i][j])`` in colour
    ``colors[i][j]`` with legend label ``labels[j]`` (labels are shared by
    all subplots). Subplots fill an ``n_rows`` x ``len(Xs) // n_rows`` grid
    in ``fig.add_subplot`` order.

    Args:
        Xs, Ys: per-subplot lists of x / y series.
        labels: one legend label per curve, or ``None`` for no legend.
        x_labels, y_labels, subtitles: one axis label / title per subplot.
        colors: per-subplot list of line colours.
        title: figure-level title, drawn ``fontsize + 10``.
        n_rows: number of subplot rows; must divide ``len(Xs)``.
        xlims, ylims: optional per-subplot axis limits (``None`` entries
            leave that axis to matplotlib).
        figsize: figure size; defaults to 9 x 5 per subplot cell.
        grid: draw a dashed grey grid on every subplot.
        top_right_visible: if False, hide the top and right spines.
        linewidth, fontsize: line width and text size for all subplots.
        yaxis_formatters: optional per-subplot ``(value, pos) -> str``
            callables wrapped in ``ticker.FuncFormatter``.
        save_path: output path; names ending in "pdf"/"svg" are written in
            that format at 600 dpi, everything else uses savefig defaults.
        close: close the figure after saving so repeated calls do not
            accumulate open pyplot figures.

    Returns:
        The matplotlib Figure that was drawn and saved (already closed when
        ``close`` is True).

    Raises:
        AssertionError: when the per-subplot / per-curve inputs have
            inconsistent lengths (checks are ``assert`` statements and are
            skipped under ``python -O``).
    """
    n_plots = len(Xs)
    assert len(Ys) == n_plots, "Xs and Ys must describe the same number of subplots"
    assert len(colors) == n_plots, "colors needs one colour list per subplot"
    assert len(Xs[0]) == len(colors[0]), "colors[0] needs one colour per curve"
    for index, (sub_xs, sub_ys) in enumerate(zip(Xs, Ys)):
        # zip() below would otherwise silently drop the unmatched curves.
        assert len(sub_xs) == len(sub_ys), (
            f"subplot {index}: {len(sub_xs)} x-series but {len(sub_ys)} y-series"
        )

    if labels is not None:
        assert len(Xs[0]) == len(labels), "labels needs one entry per curve"
    else:
        labels = [None] * len(Xs[0])
    if xlims is None:
        xlims = [None] * n_plots
    if ylims is None:
        ylims = [None] * n_plots
    if yaxis_formatters is None:
        yaxis_formatters = [None] * n_plots

    # A short per-subplot list would make zip() silently drop whole subplots.
    per_subplot = {
        "x_labels": x_labels,
        "y_labels": y_labels,
        "subtitles": subtitles,
        "xlims": xlims,
        "ylims": ylims,
        "yaxis_formatters": yaxis_formatters,
    }
    for name, values in per_subplot.items():
        assert len(values) >= n_plots, (
            f"{name} has {len(values)} entries but there are {n_plots} subplots"
        )

    assert n_plots % n_rows == 0, "n_rows must evenly divide the number of subplots"
    n_columns = n_plots // n_rows
    if figsize is None:
        figsize = (n_columns * 9, n_rows * 5)
    fig = plt.figure(figsize=figsize)

    # Without any label, ax.legend() only emits a "no labelled artists" warning.
    show_legend = any(label is not None for label in labels)

    subplot_specs = zip(
        Xs, Ys, x_labels, y_labels, xlims, ylims, subtitles, yaxis_formatters, colors
    )
    for index, (sub_xs, sub_ys, x_label, y_label, xlim, ylim, subtitle,
                yaxis_formatter, sub_colors) in enumerate(subplot_specs, start=1):
        ax = fig.add_subplot(n_rows, n_columns, index)
        for x, y, label, curve_color in zip(sub_xs, sub_ys, labels, sub_colors):
            ax.plot(x, y, color=curve_color, linestyle="-", linewidth=linewidth, label=label)
        ax.tick_params(axis="both", which="both", labelsize=fontsize)
        ax.set_xlabel(x_label, fontsize=fontsize)
        ax.set_ylabel(y_label, fontsize=fontsize)
        ax.set_title(subtitle, fontsize=fontsize)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        if yaxis_formatter is not None:
            ax.yaxis.set_major_formatter(ticker.FuncFormatter(yaxis_formatter))
        if grid:
            ax.grid(color="gray", linestyle="--", linewidth=0.5)
        if not top_right_visible:
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
        # Single legend call per subplot (previously created twice).
        if show_legend:
            ax.legend(fontsize=fontsize)

    fig.suptitle(title, fontsize=fontsize + 10)
    fig.tight_layout()
    _save_figure(fig, save_path)
    if close:
        plt.close(fig)
    return fig
# Metric dictionaries for each run; indexed below by keys such as
# "train/loss" and "eval/top1_acc".
baseline = torch.load("baseline_metrics.pt")
oursv1 = torch.load("our_metrics.pt")
oursv2 = torch.load("oursv2_metrics.pt")

# Derive the epoch axis from the data instead of hard-coding 90 in two places.
# Assumes every run logged one value per epoch for every metric — TODO confirm.
num_epochs = len(baseline["train/loss"])
x = list(range(1, num_epochs + 1))

runs = [baseline, oursv1, oursv2]
metric_keys = ["train/loss", "eval/loss", "train/top1_acc", "eval/top1_acc"]

show_available_font()
set_global_matplotlib_font("Times New Roman")
plot_line(
    Xs=[[x] * len(runs)] * len(metric_keys),
    # One subplot per metric, one curve per run.
    Ys=[[run[key] for run in runs] for key in metric_keys],
    labels=["baseline", "Ours-A", "Ours-B"],
    x_labels=["Epoch"] * len(metric_keys),
    y_labels=["Loss", "Loss", "Acc. (%)", "Acc. (%)"],
    xlims=[[0, num_epochs]] * len(metric_keys),
    ylims=[[0, 6], [0, 6], [0, 1], [0, 1]],
    colors=[[COLORS[6]["dark_green"], COLORS[7]["dark_blue"], COLORS[7]["dark_red"]]] * len(metric_keys),
    title="Evolution of Loss/Accuracy",
    subtitles=["Training Loss", "Eval Loss", "Training Top-1 Acc.", "Eval Top-1 Acc."],
    n_rows=2,
    grid=True,
    top_right_visible=False,
    # Accuracy subplots (ylim 0-1) are labelled as percentages.
    yaxis_formatters=[None, None, to_percent, to_percent],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment