Created
May 10, 2021 12:21
-
-
Save Ethan00Si/86db6a87528f3caad0cee8bec3e7819b to your computer and use it in GitHub Desktop.
Adjust the order of legend entries with matplotlib.pyplot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plot several linear series and list them in the legend in a fixed order.
# Each series is labelled "w<k>"; mapping_order gives the slot of every known
# label, which selects its line style / colour and its legend position,
# independent of the order in which the series were plotted.
import pandas as pd
import numpy as np

# Plotting/legend slot of every known series label.
mapping_order = {"w0": 0, "w1": 1, "w2": 2, "w3": 3, "w4": 4, "w5": 5, "w6": 6}
# Format strings (line + marker) and colours, indexed by slot.
marker = ['-,', '-o', '-^', '-v', '-s', '-p', '-*', '-h', '-+', '-x', '-1', '-2', '-3', '-4']
colors = ['b', 'g', 'r', 'c', 'orangered', 'rosybrown', 'black', 'crimson', 'navy', 'chocolate', 'maroon']


def legend_rank(label, order_map=None):
    """Return the slot of *label* in *order_map* (default: ``mapping_order``).

    The label must match a key exactly, so ``"w10"`` is never mistaken for
    ``"w1"``. Unknown labels get ``len(order_map)``, which sorts them after
    every known label instead of overwriting one of the known slots.
    """
    if order_map is None:
        order_map = mapping_order
    return order_map.get(label, len(order_map))


def sort_legend_entries(handles, labels, order_map=None):
    """Return ``(handles, labels)`` reordered by :func:`legend_rank`.

    The sort is stable: entries with equal rank (e.g. several unknown labels)
    keep their plotting order, and no entry is dropped.
    """
    ranked = sorted(zip(handles, labels),
                    key=lambda pair: legend_rank(pair[1], order_map))
    return [h for h, _ in ranked], [lab for _, lab in ranked]


def main():
    """Draw one line per weight in ``w_list`` and show a legend sorted by slot."""
    # Imported here so the helpers above can be used without a plotting backend.
    import matplotlib.pyplot as plt

    _, ax = plt.subplots()
    w_list = [3, 4, 2, 1, 5]
    xs = list(range(0, 100, 10))
    # w_list is not in slot order; the legend is re-sorted after plotting.
    for w in w_list:
        file_name = 'w' + str(w)
        rank = legend_rank(file_name)
        # Modulo keeps unknown / out-of-range slots inside the style lists.
        ax.plot(xs, [w * x for x in xs], marker[rank % len(marker)],
                c=colors[rank % len(colors)], label=file_name, linewidth=1, ms=8)

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(*sort_legend_entries(handles, labels))
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
# Embed fonts as TrueType (Type 42) instead of Type 3 in PDF/PS output.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rcParams["font.family"] = "Times New Roman"
import pandas as pd
import numpy as np

num_bar = 6  # number of bars (ablation variants) per framework
# x-axis category of each bar; alternates UI, I, UI, I, ...
Framework = ['IV4Rec+(UI)' if i % 2 == 0 else 'IV4Rec+(I)' for i in range(2 * num_bar)]
Ablation_part = ['w/o causal', 'w/o non-causal', 'w/o user', 'w/o item', 'w/o adaptive fusion', 'IV4Rec+']
# Hue of each bar: every ablation variant once per framework, in step with Framework.
Models = [item for item in Ablation_part for _ in range(2)]
# NOTE(review): 'name' is not used below, and it says IV4DIN+ while the plotted
# x labels say IV4Rec+ -- confirm which name is intended.
data_list = [
    {
        "AUC": [0.5988, 0.6081, 0.6178, 0.6230, 0.6212, 0.6275],
        "MRR": [0.4848, 0.4884, 0.4932, 0.5005, 0.5014, 0.5051],
        "N": [i for i in range(6)],
        'name': 'IV4DIN+(UI)'
    },
    {
        "AUC": [0.6087, 0.6203, 0.6178, 0.6266, 0.6255, 0.6269],
        "MRR": [0.4893, 0.4999, 0.4932, 0.4995, 0.4972, 0.5092],
        "N": [i for i in range(6)],
        "name": "IV4DIN+(I)"
    }
]
# Interleave the two result rows (UI, I, UI, I, ...) so AUC lines up with Framework and Models.
AUC = [auc for pair in zip(data_list[0]['AUC'], data_list[1]['AUC']) for auc in pair]
print(AUC)
a = pd.DataFrame({"AUC": AUC,
                  'Framework': Framework,
                  "Models": Models
                  }
                 )
bottom = 0.6163  # y value of the dashed reference line labelled 'DIN'
font_size = 17
# sns.set() and sns.set_style() both reset 'font.family' to sans-serif, which
# silently discarded the Times New Roman setting; pass the font to the theme call.
sns.set(style='whitegrid', font='Times New Roman', font_scale=1.5)
hatches = ["/"] * num_bar + ['.'] * num_bar
bar_plot = sns.barplot(x='Framework', y='AUC', data=a, hue='Models', linewidth=3, palette='Paired')
# Size the x tick labels on the populated axes (previously applied before the bars existed).
bar_plot.tick_params(axis='x', labelsize=font_size)
start_y2 = 0.5950
end_y2 = 0.6290
bar_plot.set_ylim([start_y2, end_y2])
bar_plot.set_yticks(np.arange(start_y2, end_y2, 0.005))
bar_plot.set_ylabel("AUC")
bar_plot.axhline(y=bottom, color='cyan', linestyle=(0, (5, 1)), label='DIN')
bar_plot.set(xlabel=None)  # turn off x label
# NOTE(review): seaborn appears to add bars hue level by hue level, so the first
# num_bar patches are likely the first half of the Models levels (for both
# frameworks), not one framework -- confirm which grouping the hatches should mark.
# zip() avoids an IndexError if the axes hold more patches than there are hatches.
for patch, hatch in zip(bar_plot.patches, hatches):
    patch.set_hatch(hatch)
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.9), borderaxespad=0.5)
plt.savefig('recombing.pdf', format='pdf', bbox_inches='tight')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
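To avoid Type 3 fonts in exported PDF/PS files, set `pdf.fonttype` and `ps.fonttype` to 42 (TrueType), as in the script above.
To use the Times New Roman font, set `font.family` to "Times New Roman". Note that seaborn's `sns.set()` and `sns.set_style()` reset `font.family`, so pass the font to those calls as well (e.g. `sns.set(font='Times New Roman')`).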