Created
December 11, 2019 09:24
-
-
Save sakethramanujam/05b9a5126e4981245d89f2eb88316071 to your computer and use it in GitHub Desktop.
Saving Feature Importances
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
def _build_charmap():
    """Build the feature-name -> display-label table served by ``charmap``.

    Keys have the form ``<statistic>_<component>`` (for example ``mean_x``).
    Each value is a symbol or word followed by a mathtext subscript such as
    ``$_{\\ x}$``.  The bare Greek letters are also exposed under the keys
    ``mu``, ``sigma`` and ``beta``.
    """
    mu, sigma, beta = chr(956), chr(963), chr(946)

    # statistic -> (text before the subscript, text that opens the subscript)
    statistics = {
        'mean': (mu, ''),
        'std': (sigma, ''),
        'min': ('min', ''),
        'max': ('max', ''),
        'range': ('range', ''),
        'skew': (beta, '1'),
        'kur': (beta, '2'),
    }
    # component suffix in the feature name -> text shown in the subscript
    components = {'x': 'x', 'y': 'y', 'sum': 'I', 'mag': 'mag', 'dir': 'dir'}

    table = {}
    for stat, (head, lead) in statistics.items():
        for comp, shown in components.items():
            # Raw string: "\ " is a literal backslash + space, not a Python
            # escape sequence.
            table[f'{stat}_{comp}'] = head + '$_{' + lead + r'\ ' + shown + '}$'
    table.update(mu=mu, sigma=sigma, beta=beta)
    return table


# Built once at import time instead of on every charmap() call.
_CHARMAP = _build_charmap()


def charmap(label):
    """Return the display label for a feature name.

    Args:
        label: A feature name such as ``'mean_x'`` or ``'kur_dir'``, or one
            of ``'mu'``, ``'sigma'``, ``'beta'`` for the bare Greek letter.

    Returns:
        The label string, or ``None`` when ``label`` is not a known key.
    """
    return _CHARMAP.get(label)
def _checkpath(*paths): | |
for path in paths: | |
path = str(path) | |
if not os.path.exists(path): | |
os.makedirs(path) | |
def _imo_to_file(importances, filename): | |
labels = train_data.columns | |
zipped = list(zip(labels,importances)) | |
df = pd.DataFrame(data=zipped,columns=['Feature', 'Importance Score']) | |
df.to_csv(filename, index=False) | |
def _get_hatch(i): | |
hatch_map={0:'x',1:"*",2:"+",3:"|",4:"."} | |
return hatch_map.get(i%5) | |
def _get_bar_color(i): | |
if i>=0 and i<5: | |
return "red" | |
elif i>=5 and i<10: | |
return "blue" | |
elif i>=10 and i<15: | |
return "orange" | |
elif i>=15 and i<20: | |
return "green" | |
elif i>=20 and i<25: | |
return "indigo" | |
elif i>=25 and i<30: | |
return "grey" | |
elif i>=30 and i<35: | |
return "brown" | |
def save_importances(model, targets, *, save=False, show=True, first_only=True):
    """Plot the feature importances of the estimators in ``model``.

    For each processed estimator a bar chart is drawn on the current pyplot
    figure: one bar per feature, coloured by ``_get_bar_color`` and hatched
    by ``_get_hatch``, annotated with the score rounded to two decimals, and
    given a legend that explains the markers, the colours and each numbered
    feature.

    The default keyword values reproduce the previous behaviour: only the
    first estimator is plotted, the figure is shown, and nothing is written
    apart from creating the ``images/`` and ``csv/`` output directories.

    Args:
        model: Object exposing ``estimators_``; each estimator must expose
            ``feature_importances_``.  Presumably a fitted multi-output
            wrapper with one estimator per target -- TODO confirm.
        targets: Sequence indexed by estimator position; each entry is used
            as the plot title and as the output file stem, so it must be a
            string.
        save: When true, write ``<title>.png`` and ``<title>.csv`` below
            ``path_to_save_fi``.
        show: When true, call ``plt.show()`` for each figure.
        first_only: When true, stop after the first estimator.

    Relies on the module-level names ``get_fe_lable`` and ``path_to_save_fi``,
    which are defined outside this function.
    """
    # (marker, legend text) for the component encoded by the bar hatch.
    marker_legend = [
        ('x', 'x'),
        ('*', 'y'),
        ('+', 'I (intensity)'),
        ('|', 'mag (magnitude)'),
        ('.', 'dir (direction)'),
    ]
    # (colour, legend text) for the statistic encoded by the bar colour.
    color_legend = [
        ('red', f'{charmap("mu")} (mean)'),
        ('blue', f'{charmap("sigma")} (standard deviation)'),
        ('orange', 'min (minimum)'),
        ('green', 'max (maximum)'),
        ('indigo', 'range'),
        ('grey', f'{charmap("beta")}$_1$ (skewness)'),
        ('brown', f'{charmap("beta")}$_2$ (kurtosis)'),
    ]

    for target_idx, estimator in enumerate(model.estimators_):
        importances = estimator.feature_importances_
        indices = list(range(len(importances)))

        fig = plt.gcf()
        fig.set_size_inches(16, 10)
        for pos, score in enumerate(importances):
            plt.bar(pos, score, color=_get_bar_color(pos), hatch=_get_hatch(pos))
        plt.ylim(0, max(importances) + 0.005)
        plt.rcParams["font.family"] = "serif"

        title = targets[target_idx]
        plt.title(f'Feature Importance Plot for Prediction of {title}')
        plt.xlabel('Feature')
        plt.ylabel('Feature Importance (Normalized Score)')
        labels = get_fe_lable(indices)
        # One tick per bar, taken from the importances themselves rather
        # than from an unrelated global array.
        plt.xticks(indices, indices, fontsize=12)
        for pos, score in enumerate(importances):
            plt.annotate(round(score, 2), (pos - 0.5, score + 0.001), fontsize=11.5)

        # The output directories are created even when save is false, as
        # the original code did.
        path = path_to_save_fi
        img_path = path + 'images/'
        csv_path = path + 'csv/'
        _checkpath(img_path, csv_path)
        img_name = img_path + title + ".png"
        csv_name = csv_path + title + ".csv"

        legend_elements = [
            Line2D([0], [0], marker=marker, linestyle="", label=text)
            for marker, text in marker_legend
        ]
        legend_elements.extend(
            Line2D([0], [0], color=color, lw=6, alpha=1, label=text)
            for color, text in color_legend
        )
        # Invisible handles that map each tick number to its feature label.
        for i, label in enumerate(labels):
            label = charmap(label)
            legend_elements.append(
                Line2D([], [], color='white', lw=0, label=f'{i} - {label}')
            )
        legend = plt.legend(handles=legend_elements, bbox_to_anchor=(1, 1),
                            fontsize=12, ncol=2)
        legend.get_frame().set_edgecolor('b')

        if save:
            # Save before showing: some backends release the figure on show().
            plt.savefig(img_name, bbox_inches='tight')
            _imo_to_file(importances, csv_name)
        if show:
            plt.show()

        if first_only:
            break
        # Start the next estimator on a fresh figure so bars do not pile up.
        plt.close(fig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment