Created
May 10, 2021 12:21
-
-
Save Ethan00Si/86db6a87528f3caad0cee8bec3e7819b to your computer and use it in GitHub Desktop.
adjust order of legend using matplotlib.pyplot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Toy data: for every weight w the series is the line y = w * x sampled at
# x = 0, 10, ..., 90.
w_list = [3, 4, 2, 1, 5]
data_list = []
for w in w_list:
    x = list(range(0, 100, 10))
    y = [w * v for v in x]
    data_list.append((x, y))

# Plot results in a fixed order. mapping_order assigns every series name
# "w<k>" a slot; the slot selects the series' marker and colour here and its
# legend position below, so the figure does not depend on plotting order.
mapping_order = {"w0": 0, "w1": 1, "w2": 2, "w3": 3, "w4": 4, "w5": 5, "w6": 6}
marker = ['-,', '-o', '-^', '-v', '-s', '-p', '-*', '-h', '-+', '-x', '-1', '-2', '-3', '-4']
colors = ['b', 'g', 'r', 'c', 'orangered', 'rosybrown', 'black', 'crimson', 'navy', 'chocolate', 'maroon']
for w, (x, y) in zip(w_list, data_list):
    # In real use the series name would be the name of the input data file;
    # here it is built from the weight.
    file_name = 'w' + str(w)
    # Names missing from mapping_order fall back to the last marker/colour
    # (index -1), exactly as the original lookup did.
    order = mapping_order.get(file_name, -1)
    ax.plot(x, y, marker[order], c=colors[order], label=file_name, linewidth=1, ms=8)

# Re-order the legend so it follows mapping_order even when the lines were
# plotted in a different order.
handles, labels = ax.get_legend_handles_labels()


def _legend_rank(label):
    """Return the mapping_order slot of the first key that prefixes *label*.

    Labels that match no key rank after every known one. The original wrote
    them into index -1, overwriting whatever series owned the last slot.
    """
    for key, slot in mapping_order.items():
        if label.startswith(key):
            return slot
    return float('inf')


# sorted() is stable: entries with equal rank keep their plotting order, so
# two labels sharing a slot no longer overwrite each other. Unused slots
# simply produce no entry, so no placeholder filtering is needed.
entries = sorted(zip(handles, labels), key=lambda entry: _legend_rank(entry[1]))
res_handles = [handle for handle, _ in entries]
res_labels = [label for _, label in entries]
ax.legend(res_handles, res_labels)
plt.tight_layout()
plt.show()
这是画折线图的代码，包括设置 font family: Times New Roman，
以及画 horizontal line（水平参考线）。
# Line chart of AUC for the IV4Rec variants as N varies. The plain
# baselines are drawn as horizontal reference lines. Text uses Times New
# Roman, and fonts are embedded as TrueType (Type 42) instead of Type 3.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

font_size = 13
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Every rcParam must be set BEFORE the figure is created. Axis labels read
# their default size when the Axes is built, so the original's rc changes
# made after plt.subplots() never reached the x/y labels.
plt.rcParams.update({
    "font.family": "Times New Roman",
    "font.size": font_size,
    "axes.titlesize": font_size,  # fontsize of the axes title
    "axes.labelsize": font_size,  # fontsize of the x and y labels
})

fig, ax = plt.subplots()
data_list = [
    {"AUC": [0.6521, 0.6513, 0.6533, 0.6561], "N": ['3', '5', '7', '10'], 'name': 'IV4REC-DIN'},
    {"AUC": [0.6504, 0.6503, 0.6511, 0.6574], "N": ['3', '5', '7', '10'], "name": "IV4REC-NRHUB"},
]
ax.axhline(y=0.6512, color='navy', linestyle=(0, (1, 1)), label='DIN')
ax.axhline(y=0.6455, color='cyan', linestyle=(0, (5, 1)), label='NRHUB')

marker = ['-o', '-^', '-v', '-s', '-p', '-*', '-h', '-+', '-x', '-1', '-2', '-3', '-4']
colors = ['b', 'g', 'r', 'c', 'orangered', 'rosybrown', 'black', 'crimson', 'navy', 'chocolate', 'maroon']
# Series i is drawn with marker[i] and colors[i].
for data, line_marker, line_color in zip(data_list, marker, colors):
    # 'N' holds strings, so matplotlib uses a categorical x axis and the
    # values 3, 5, 7, 10 are spaced evenly.
    ax.plot(data['N'], data['AUC'], line_marker, c=line_color,
            label=data['name'], linewidth=1, ms=8)

ax.tick_params(labelsize=font_size)
ax.set_xlabel('# clicked queries', fontsize=font_size)
ax.set_ylabel("AUC", fontsize=font_size)
ax.legend()
plt.tight_layout()
# plt.show()
plt.savefig('different_N.pdf')
用sns画柱状图
# Grouped seaborn bar chart: AUC of IV4Rec-NRHUB and IV4Rec-DIN for each
# treatment embedding, with one hatch pattern per embedding.
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
import pandas as pd

font_size = 17
# A single set_theme() call replaces the original two sns.set() calls.
# Every sns.set()/set_theme() call resets ALL style and context parameters.
# In the original, the second call silently turned 'whitegrid' back into the
# default style, and both reset font.family to sans-serif, discarding
# Times New Roman. The theme must also be applied before the Axes is created.
sns.set_theme(style='whitegrid', font_scale=1.5, font='Times New Roman')
# Embed TrueType (Type 42) fonts instead of Type 3 fonts in PDF/PS output.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

# Raw strings: the mathtext labels contain backslashes (\widehat, \mathcal,
# \~) that are invalid escape sequences in normal strings.
embeddings = [
    r'only fitted values: $ \widehat{\mathcal{T}}_{u,i} $',
    r'only residuals: $\~{\mathcal{T}}_{u,i}$',
    r'original treatment: $\mathcal{T}_{u,i}$',
    'concatenate fitted values & residuals',
    r'reconstructed treatment: $\mathcal{T}^{\mathrm{re}}_{u,i}$',
]
a = pd.DataFrame({
    "AUC": [0.6380, 0.6462, 0.6455, 0.6567, 0.6574, 0.6452, 0.6501, 0.6512, 0.6536, 0.6561],
    "Embedding": embeddings * 2,
    "models": ["IV4Rec-NRHUB"] * 5 + ["IV4Rec-DIN"] * 5,
})

hatches = ["/", "/", "o", "o", "*", "*", "\\", "\\", '+', '+']
bar_plot = sns.barplot(x='models', y='AUC', data=a, hue='Embedding')
bar_plot.set_ylim(0.635, 0.66)
# Set after plotting so it applies to the categorical tick labels seaborn creates.
bar_plot.tick_params(labelsize=font_size)
# NOTE(review): assumes bar_plot.patches lists bars grouped by hue level
# (both models of one embedding consecutively), so each embedding gets one
# hatch. Patch order can differ between seaborn versions; verify the output.
# zip() stops at the shorter list instead of raising on extra patches.
for patch, hatch in zip(bar_plot.patches, hatches):
    patch.set_hatch(hatch)
plt.legend(loc='best', bbox_to_anchor=(1.0, 1.0), borderaxespad=0.5, title='Treatment')
plt.savefig('recombing.pdf', format='pdf', bbox_inches='tight')
# Line plot of the sorted cosine similarities stored in sim.npy.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

font_size = 13
# set_theme() resets font.family. The original called sns.set() after
# selecting Times New Roman, which undid that choice, so the font is passed
# to set_theme() directly.
sns.set_theme(style='whitegrid', font='Times New Roman')

# NOTE(review): allow_pickle=True lets np.load execute pickled objects, so
# only load trusted files. Presumably a 1-D array of scores; confirm.
sim = np.load('sim.npy', allow_pickle=True)
# The 10 smallest values are skipped, as in the original.
# NOTE(review): the reason is not visible here; confirm.
sim_tail = np.sort(sim)[10:]
p = sns.lineplot(y=sim_tail, x=np.arange(0, sim_tail.shape[0], 1))

# One tick every 500k items. The labels ('0', '0.5m', '1.0m', ...) are
# derived from the ticks so their counts always match; set_xticks raises
# ValueError when labels and ticks differ in length, which a fixed
# 7-label list would trigger for other data sizes.
ticks = range(0, sim_tail.shape[0], 500000)
tick_labels = ['0' if t == 0 else f'{t / 1e6:.1f}m' for t in ticks]
p.set_xticks(ticks, labels=tick_labels)
p.set_xlabel('item index', fontsize=font_size)
p.set_ylabel('cosine similarity', fontsize=font_size)
plt.tight_layout()
plt.savefig('similarity.pdf', format='pdf')
To avoid Type 3 fonts in the exported PDF/PS files, add the following snippet (it makes matplotlib embed TrueType fonts, font type 42):
import matplotlib

# Embed fonts as TrueType (Type 42) rather than Type 3 in PDF and PostScript output.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
To use the "Times New Roman" font, add the following snippet:
import matplotlib.pyplot as plt
import seaborn as sns

# Select Times New Roman through matplotlib's rcParams and through seaborn's
# style API. Note that a later sns.set()/sns.set_theme() call resets
# font.family, so those calls should receive the font themselves.
plt.rcParams.update({"font.family": "Times New Roman"})
sns.set_style({'font.family': 'Times New Roman'})
# Ablation bar chart: AUC of IV4Rec+(UI) and IV4Rec+(I) with individual
# components removed. The DIN baseline is drawn as a dashed horizontal line.
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
import pandas as pd
import numpy as np

font_size = 17
# One set_theme() call replaces sns.set(font_scale=1.5) + sns.set_style(...).
# Both of those reset font.family to sans-serif, silently discarding the
# Times New Roman setting made before them.
sns.set_theme(style='whitegrid', font_scale=1.5, font='Times New Roman')
# Embed TrueType (Type 42) fonts instead of Type 3 fonts in PDF/PS output.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

num_bar = 6  # number of bars of the same kind (one per ablation variant)
# Rows alternate between the two frameworks: UI, I, UI, I, ...
Framework = ['IV4Rec+(UI)' if i % 2 == 0 else 'IV4Rec+(I)' for i in range(2 * num_bar)]
Ablation_part = ['w/o causal', 'w/o non-causal', 'w/o user', 'w/o item', 'w/o adaptive fusion', 'IV4Rec+']
# Each ablation variant appears twice in a row, once per framework.
Models = [item for item in Ablation_part for _ in range(2)]
data_list = [
    {
        "AUC": [0.5988, 0.6081, 0.6178, 0.6230, 0.6212, 0.6275],
        "MRR": [0.4848, 0.4884, 0.4932, 0.5005, 0.5014, 0.5051],
        "N": list(range(6)),
        'name': 'IV4DIN+(UI)',
    },
    {
        "AUC": [0.6087, 0.6203, 0.6178, 0.6266, 0.6255, 0.6269],
        "MRR": [0.4893, 0.4999, 0.4932, 0.4995, 0.4972, 0.5092],
        "N": list(range(6)),
        "name": "IV4DIN+(I)",
    },
]
# Interleave the two result lists so row i lines up with Framework[i] / Models[i].
AUC = [auc for pair in zip(data_list[0]['AUC'], data_list[1]['AUC']) for auc in pair]
print(AUC)
a = pd.DataFrame({"AUC": AUC, 'Framework': Framework, "Models": Models})

bottom = 0.6163  # y value of the DIN baseline line
hatches = ["/"] * num_bar + ['.'] * num_bar
bar_plot = sns.barplot(x='Framework', y='AUC', data=a, hue='Models', linewidth=3, palette='Paired')
start_y2 = 0.5950
end_y2 = 0.6290
bar_plot.set_ylim([start_y2, end_y2])
bar_plot.set_yticks(np.arange(start_y2, end_y2, 0.005))
bar_plot.set_ylabel("AUC")
# Set after plotting so it applies to the categorical tick labels seaborn creates.
bar_plot.tick_params(axis='x', labelsize=font_size)
bar_plot.axhline(y=bottom, color='cyan', linestyle=(0, (5, 1)), label='DIN')
bar_plot.set(xlabel=None)  # turn off x label
# NOTE(review): if bar_plot.patches is ordered by hue level (two bars per
# ablation variant), this hatches the first three variants with '/' and
# the last three with '.'. Confirm that this is the intended grouping.
# zip() stops at the shorter list instead of raising on extra patches.
for patch, hatch in zip(bar_plot.patches, hatches):
    patch.set_hatch(hatch)
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.9), borderaxespad=0.5)
# NOTE(review): the output name 'recombing.pdf' looks copy-pasted from
# another figure; rename it if several figures are written to one directory.
plt.savefig('recombing.pdf', format='pdf', bbox_inches='tight')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
代码里的mapping_order固定了不同的w对应的线的颜色、在legend中的顺序以及marker的形状。
如果不使用38行以后的代码调整legend的顺序：


使用调整的话