Skip to content

Instantly share code, notes, and snippets.

@Ethan00Si
Created May 10, 2021 12:21
Show Gist options
  • Save Ethan00Si/86db6a87528f3caad0cee8bec3e7819b to your computer and use it in GitHub Desktop.
Save Ethan00Si/86db6a87528f3caad0cee8bec3e7819b to your computer and use it in GitHub Desktop.
Adjust the order of legend entries with matplotlib.pyplot
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots()

# Toy data: one straight line y = w * x for each weight w. The weights are
# deliberately out of order to show that styling and legend order are
# independent of the order in which the lines are plotted.
w_list = [3, 4, 2, 1, 5]
x_values = list(range(0, 100, 10))
data_list = [(x_values, [w * x for x in x_values]) for w in w_list]

# Fixed rank for every label: it selects the line's marker and colour below
# and also its position in the legend.
mapping_order = {"w0": 0, "w1": 1, "w2": 2, "w3": 3, "w4": 4, "w5": 5, "w6": 6}
marker = ['-,', '-o', '-^', '-v', '-s', '-p', '-*', '-h', '-+', '-x', '-1', '-2', '-3', '-4']
colors = ['b', 'g', 'r', 'c', 'orangered', 'rosybrown', 'black', 'crimson', 'navy', 'chocolate', 'maroon']


def _plot_rank(label):
    """Return the mapping_order rank of *label*, or None if no key matches.

    Labels may be data-file names that only start with a key (e.g. 'w3_run1'
    or 'w3.csv'), so a key matches when the label equals it or continues with
    a non-alphanumeric character. This keeps 'w10' from being mistaken for
    'w1', which a bare prefix test would do.
    """
    for key, rank in mapping_order.items():
        if label.startswith(key) and not label[len(key):len(key) + 1].isalnum():
            return rank
    return None


def _legend_sort_key(handle_and_label):
    """Sort by rank; labels without a rank go last, in plotting order."""
    rank = _plot_rank(handle_and_label[1])
    return (rank is None, 0 if rank is None else rank)


# Plot every series, styled by its label's rank so a given w always gets the
# same marker and colour. Unmapped labels fall back to the last style (-1).
for w, (x, y) in zip(w_list, data_list):
    file_name = 'w' + str(w)
    rank = _plot_rank(file_name)
    style = -1 if rank is None else rank
    ax.plot(x, y, marker[style], c=colors[style], label=file_name, linewidth=1, ms=8)

# Rebuild the legend in mapping_order order instead of plotting order.
# Sorting (rather than writing into fixed slots) leaves no gaps when
# mapping_order has more keys than there are lines, and never lets two
# labels overwrite each other.
handles, labels = ax.get_legend_handles_labels()
ordered = sorted(zip(handles, labels), key=_legend_sort_key)
ax.legend([handle for handle, _ in ordered], [label for _, label in ordered])
plt.tight_layout()
plt.show()
@Ethan00Si
Copy link
Author

代码里的mapping_order为每个w固定了对应线条的颜色、marker的形状以及在legend中的顺序。

如果不使用第38行之后的代码调整legend的顺序,效果如下:
before
使用这段代码调整legend的顺序后,效果如下:
after

@Ethan00Si
Copy link
Author

Ethan00Si commented Oct 16, 2021

这是画折线图的代码,包括将字体(font family)设置为Times New Roman,
以及画水平参考线(horizontal line)。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Times New Roman"
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

# Font sizes have to be configured before the figure exists: the x/y axis
# label Text objects take their size from rcParams when the Axes is created,
# so rc updates made after plt.subplots() never reach the axis labels.
font_size = 13
plt.rcParams.update({'font.size': font_size})
plt.rc('axes', titlesize=font_size)  # fontsize of the axes title
plt.rc('axes', labelsize=font_size)  # fontsize of the x and y labels

fig, ax = plt.subplots()

# AUC of each model for every value of N. N is stored as strings, so
# matplotlib draws a categorical (evenly spaced) x axis.
data_list = [
    {"AUC": [0.6521, 0.6513, 0.6533, 0.6561], "N": ['3', '5', '7', '10'], 'name': 'IV4REC-DIN'},
    {"AUC": [0.6504, 0.6503, 0.6511, 0.6574], "N": ['3', '5', '7', '10'], "name": "IV4REC-NRHUB"},
]

# Horizontal reference lines for the base models (dotted and dashed).
ax.axhline(y=0.6512, color='navy', linestyle=(0, (1, 1)), label='DIN')
ax.axhline(y=0.6455, color='cyan', linestyle=(0, (5, 1)), label='NRHUB')

marker = ['-o', '-^', '-v', '-s', '-p', '-*', '-h', '-+', '-x', '-1', '-2', '-3', '-4']
colors = ['b', 'g', 'r', 'c', 'orangered', 'rosybrown', 'black', 'crimson', 'navy', 'chocolate', 'maroon']

# The i-th series gets the i-th marker and colour.
for order, series in enumerate(data_list):
    ax.plot(series['N'], series['AUC'], marker[order], c=colors[order],
            label=series['name'], linewidth=1, ms=8)

ax.tick_params(axis='both', labelsize=font_size)
ax.set_xlabel('# clicked queries')
ax.set_ylabel("AUC")
ax.legend()
plt.tight_layout()
plt.savefig('different_N.pdf')

截屏2021-10-16 下午7 42 36

@Ethan00Si
Copy link
Author

Ethan00Si commented Oct 16, 2021

用sns画柱状图

import matplotlib.pyplot as plt
import seaborn as sns
# plt.style.use('ggplot')
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rcParams["font.family"] = "Times New Roman"
# matplotlib.rcParams['font.sans-serif'] = ['SimHei']
# matplotlib.rcParams['axes.unicode_minus']=False
import pandas as pd


# Raw strings keep the LaTeX backslashes literal; in plain literals sequences
# such as '\w', '\m' and '\~' are invalid escapes (SyntaxWarning on 3.12+).
a = pd.DataFrame({
    "AUC": [0.6380, 0.6462, 0.6455, 0.6567, 0.6574, 0.6452, 0.6501, 0.6512, 0.6536, 0.6561],
    "Embedding": [
        r'only fitted values: $ \widehat{\mathcal{T}}_{u,i} $',
        r'only residuals: $\~{\mathcal{T}}_{u,i}$',
        r'original treatment: $\mathcal{T}_{u,i}$',
        'concatenate fitted values & residuals',
        r'reconstructed treatment: $\mathcal{T}^{\mathrm{re}}_{u,i}$',
        r'only fitted values: $ \widehat{\mathcal{T}}_{u,i} $',
        r'only residuals: $\~{\mathcal{T}}_{u,i}$',
        r'original treatment: $\mathcal{T}_{u,i}$',
        'concatenate fitted values & residuals',
        r'reconstructed treatment: $\mathcal{T}^{\mathrm{re}}_{u,i}$',
    ],
    "models": ["IV4Rec-NRHUB"] * 5 + ["IV4Rec-DIN"] * 5,
})

font_size = 17
# Configure the whole theme in one call. Every sns.set()/set_theme() call
# resets the style (default 'darkgrid') and the font family (default
# 'sans-serif'), so separate calls silently discard 'whitegrid' and
# Times New Roman.
sns.set(style='whitegrid', font='Times New Roman', font_scale=1.5)

bar_plot = sns.barplot(x='models', y='AUC', data=a, hue='Embedding')
# Limits and tick sizes are applied after plotting: tick labels styled before
# the bars exist are replaced when barplot sets the categorical ticks.
bar_plot.set_ylim(0.635, 0.66)
bar_plot.tick_params(axis='both', labelsize=font_size)

# seaborn adds the bars one hue level at a time, so each consecutive pair of
# patches is one Embedding shown for both models; the hatches come in pairs.
hatches = ["/", "/", "o", "o", "*", "*", "\\", "\\", '+', '+']
for patch, hatch in zip(bar_plot.patches, hatches):
    patch.set_hatch(hatch)

# The legend is (re)built after hatching so its swatches can pick up the hatches.
plt.legend(loc='best', bbox_to_anchor=(1.0, 1.0), borderaxespad=0.5, title='Treatment')
plt.savefig('recombing.pdf', format='pdf', bbox_inches='tight')

截屏2021-10-16 下午7 44 26

@Ethan00Si
Copy link
Author

Ethan00Si commented Feb 16, 2022

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects; only use it on trusted files (a plain numeric array does not need it).
sim = np.load('sim.npy', allow_pickle=True)
sim_sort = np.sort(sim)
# Drop the 10 smallest similarities (presumably outliers -- TODO confirm).
values = sim_sort[10:]
num_items = values.shape[0]

font_size = 13
plt.rcParams["font.family"] = "Times New Roman"
# Pass the font to sns.set as well: sns.set() resets font.family to
# 'sans-serif' unless told otherwise, which would undo Times New Roman.
sns.set(style='whitegrid', font='Times New Roman')

p = sns.lineplot(y=values, x=np.arange(num_items))

# One tick every 500k items, labelled in millions: '0', '0.5m', '1.0m', ...
# Labels are derived from the ticks so their counts always match
# (set_xticks raises if the two lengths differ).
ticks = range(0, num_items, 500000)
tick_labels = ['0' if t == 0 else f'{t / 1e6:.1f}m' for t in ticks]

p.set_xticks(ticks, labels=tick_labels)
p.set_xlabel('item index', fontsize=font_size)
p.set_ylabel('cosine similarity', fontsize=font_size)
plt.tight_layout()
plt.savefig('similarity.pdf', format='pdf')

similarity

@Ethan00Si
Copy link
Author

To avoid embedding Type 3 fonts in PDF/PS output (TrueType/Type 42 fonts are embedded instead), add these lines:

import matplotlib
# Embed fonts as TrueType (Type 42) rather than Type 3 in PDF and PostScript output.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

To use the Times New Roman font, add these lines:

import matplotlib.pyplot as plt
# Select Times New Roman as the default font family for all text.
plt.rcParams.update({"font.family": "Times New Roman"})
import seaborn as sns
# Apply the same font family through seaborn's style API. A later sns.set()
# call would reset it, so pass font='Times New Roman' there as well.
sns.set_style({'font.family': 'Times New Roman'})

@Ethan00Si
Copy link
Author

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rcParams["font.family"] = "Times New Roman"
import pandas as pd
import numpy as np

Ablation_part = ['w/o causal', 'w/o non-causal', 'w/o user', 'w/o item', 'w/o adaptive fusion', 'IV4Rec+']
num_bar = len(Ablation_part)  # bars per framework: one per ablation variant

data_list = [
    {
        "AUC": [0.5988, 0.6081, 0.6178, 0.6230, 0.6212, 0.6275],
        "MRR": [0.4848, 0.4884, 0.4932, 0.5005, 0.5014, 0.5051],
        "N": list(range(6)),
        'name': 'IV4DIN+(UI)',
    },
    {
        "AUC": [0.6087, 0.6203, 0.6178, 0.6266, 0.6255, 0.6269],
        "MRR": [0.4893, 0.4999, 0.4932, 0.4995, 0.4972, 0.5092],
        "N": list(range(6)),
        "name": "IV4DIN+(I)",
    },
]

# Long-format rows for seaborn, interleaving the two frameworks:
# (variant 0, UI), (variant 0, I), (variant 1, UI), (variant 1, I), ...
Framework, Models, AUC = [], [], []
for model, auc_ui, auc_i in zip(Ablation_part, data_list[0]['AUC'], data_list[1]['AUC']):
    Framework += ['IV4Rec+(UI)', 'IV4Rec+(I)']
    Models += [model, model]
    AUC += [auc_ui, auc_i]
print(AUC)
a = pd.DataFrame({"AUC": AUC, 'Framework': Framework, "Models": Models})

# AUC of the DIN baseline, drawn as a horizontal reference line labelled 'DIN'.
bottom = 0.6163

font_size = 17
# Configure the whole theme in one call. Every sns.set() / sns.set_style(<name>)
# call resets font.family to sans-serif, so separate calls silently discard
# the Times New Roman setting.
sns.set(style='whitegrid', font='Times New Roman', font_scale=1.5)

bar_plot = sns.barplot(x='Framework', y='AUC', data=a, hue='Models', linewidth=3, palette='Paired')

start_y2 = 0.5950
end_y2 = 0.6290
bar_plot.set_ylim([start_y2, end_y2])
bar_plot.set_yticks(np.arange(start_y2, end_y2, 0.005))
bar_plot.set_ylabel("AUC")
# Set the tick size after plotting; barplot replaces tick labels styled earlier.
bar_plot.tick_params(axis='x', labelsize=font_size)

bar_plot.axhline(y=bottom, color='cyan', linestyle=(0, (5, 1)), label='DIN')
bar_plot.set(xlabel=None)  # turn off x label

# NOTE(review): seaborn appears to add bars one hue level (model) at a time,
# so the first num_bar patches would be the first num_bar // 2 models in both
# frameworks rather than one framework's bars -- confirm this hatch split is
# the intended one.
hatches = ["/"] * num_bar + ['.'] * num_bar
for patch, hatch in zip(bar_plot.patches, hatches):
    patch.set_hatch(hatch)

plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.9), borderaxespad=0.5)
plt.savefig('recombing.pdf', format='pdf', bbox_inches='tight')

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment