Skip to content

Instantly share code, notes, and snippets.

@sakethramanujam
Created December 11, 2019 09:24
Show Gist options
  • Save sakethramanujam/05b9a5126e4981245d89f2eb88316071 to your computer and use it in GitHub Desktop.
Save sakethramanujam/05b9a5126e4981245d89f2eb88316071 to your computer and use it in GitHub Desktop.
Saving Feature Importances
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
_MU = chr(956)     # Greek small letter mu  (mean)
_SIGMA = chr(963)  # Greek small letter sigma (standard deviation)
_BETA = chr(946)   # Greek small letter beta (skewness / kurtosis)

# Feature-name suffix -> subscript text shown in the plot label.
_AXIS_SUBSCRIPTS = {'x': 'x', 'y': 'y', 'sum': 'I', 'mag': 'mag', 'dir': 'dir'}

# Feature-name prefix -> (symbol, order digit written before the subscript).
_STAT_SYMBOLS = {
    'mean': (_MU, ''),
    'std': (_SIGMA, ''),
    'min': ('min', ''),
    'max': ('max', ''),
    'range': ('range', ''),
    'skew': (_BETA, '1'),
    'kur': (_BETA, '2'),
}

# Built once at import time instead of on every call.
# Raw strings keep the mathtext "\ " (thin space) without relying on the
# invalid-escape fallback, e.g. 'skew_sum' -> beta + '$_{1\ I}$'.
_CHARMAP = {
    f'{stat}_{axis}': symbol + r'$_{' + order + r'\ ' + sub + '}$'
    for stat, (symbol, order) in _STAT_SYMBOLS.items()
    for axis, sub in _AXIS_SUBSCRIPTS.items()
}
_CHARMAP.update({'mu': _MU, 'sigma': _SIGMA, 'beta': _BETA})


def charmap(label):
    """Return the matplotlib mathtext label for a feature column name.

    ``label`` is either ``"<stat>_<axis>"``, where stat is one of
    mean/std/min/max/range/skew/kur and axis is one of x/y/sum/mag/dir, or one
    of the bare symbol names ``"mu"``, ``"sigma"`` or ``"beta"``.

    Returns:
        The display string, or ``None`` for an unknown label.
    """
    return _CHARMAP.get(label)
def _checkpath(*paths):
for path in paths:
path = str(path)
if not os.path.exists(path):
os.makedirs(path)
def _imo_to_file(importances, filename):
labels = train_data.columns
zipped = list(zip(labels,importances))
df = pd.DataFrame(data=zipped,columns=['Feature', 'Importance Score'])
df.to_csv(filename, index=False)
def _get_hatch(i):
hatch_map={0:'x',1:"*",2:"+",3:"|",4:"."}
return hatch_map.get(i%5)
def _get_bar_color(i):
if i>=0 and i<5:
return "red"
elif i>=5 and i<10:
return "blue"
elif i>=10 and i<15:
return "orange"
elif i>=15 and i<20:
return "green"
elif i>=20 and i<25:
return "indigo"
elif i>=25 and i<30:
return "grey"
elif i>=30 and i<35:
return "brown"
def _build_legend_elements(labels):
    """Return legend handles explaining hatches, colours and numbered features.

    The hatch and colour entries describe the encodings produced by
    ``_get_hatch`` and ``_get_bar_color``. They assume features are ordered
    statistic-major, five axes per statistic (TODO confirm against
    ``get_fe_lable``). Each raw name in ``labels`` becomes an invisible entry
    ``"<i> - <pretty name>"``.
    """
    elements = [
        # Hatch pattern -> signal the statistic was computed on (same order as _get_hatch).
        Line2D([0], [0], marker='x', linestyle="", label='x'),
        Line2D([0], [0], marker='*', linestyle="", label='y'),
        Line2D([0], [0], marker='+', linestyle="", label='I (intensity)'),
        Line2D([0], [0], marker='|', linestyle="", label='mag (magnitude)'),
        Line2D([0], [0], marker='.', linestyle="", label='dir (direction)'),
        # Bar colour -> statistic (same order as _get_bar_color).
        Line2D([0], [0], color='red', lw=6, alpha=1, label=f'{charmap("mu")} (mean)'),
        Line2D([0], [0], color='blue', lw=6, alpha=1, label=f'{charmap("sigma")} (standard deviation)'),
        Line2D([0], [0], color='orange', lw=6, alpha=1, label='min (minimum)'),
        Line2D([0], [0], color='green', lw=6, alpha=1, label='max (maximum)'),
        Line2D([0], [0], color='indigo', lw=6, alpha=1, label='range'),
        Line2D([0], [0], color='grey', lw=6, alpha=1, label=f'{charmap("beta")}$_1$ (skewness)'),
        Line2D([0], [0], color='brown', lw=6, alpha=1, label=f'{charmap("beta")}$_2$ (kurtosis)'),
    ]
    for i, raw_label in enumerate(labels):
        # Fall back to the raw name so the legend never shows "None".
        pretty = charmap(raw_label) or raw_label
        elements.append(Line2D([], [], color='white', lw=0, label=f'{i} - {pretty}'))
    return elements


def save_importances(model, targets, *, save=False, first_only=True):
    """Plot (and optionally save) the feature importances of each sub-estimator.

    Args:
        model: fitted model exposing ``estimators_``, each with a
            ``feature_importances_`` sequence. This presumably means a
            multi-output wrapper (confirm against the caller).
        targets: target names indexed in parallel with ``model.estimators_``.
            Each name is used as the plot title and as the output file stem.
        save: when True, write ``<path_to_save_fi>images/<target>.png`` and
            ``<path_to_save_fi>csv/<target>.csv``, creating the directories if
            needed. Defaults to False; previously both writes were commented out.
        first_only: when True (default, the previous behaviour), stop after the
            first estimator.

    Uses the module-level names ``plt`` and ``get_fe_lable``. When ``save`` is
    True it also uses ``path_to_save_fi``, and ``_imo_to_file`` reads
    ``train_data``. Sets ``plt.rcParams["font.family"]`` globally.
    """
    for target_idx, estimator in enumerate(model.estimators_):
        importances = estimator.feature_importances_
        indices = list(range(len(importances)))

        # A fresh figure per target, so bars are never drawn onto a
        # figure left over from an earlier plot.
        plt.figure(figsize=(16, 10))
        for i in indices:
            plt.bar(i, importances[i], color=_get_bar_color(i), hatch=_get_hatch(i))
        # Small headroom so the value annotations above the tallest bar fit.
        plt.ylim(0, max(importances) + 0.005)
        plt.rcParams["font.family"] = "serif"

        title = targets[target_idx]
        plt.title(f'Feature Importance Plot for Prediction of {title}')
        plt.xlabel('Feature')
        plt.ylabel('Feature Importance (Normalized Score)')
        labels = get_fe_lable(indices)
        # Tick count derives from the importances themselves, not the global X_train.
        plt.xticks(indices, indices, fontsize=12)
        for i in indices:
            plt.annotate(round(importances[i], 2), (i - 0.5, importances[i] + 0.001), fontsize=11.5)

        plt.legend(handles=_build_legend_elements(labels), bbox_to_anchor=(1, 1),
                   fontsize=12, ncol=2).get_frame().set_edgecolor('b')

        if save:
            img_dir = path_to_save_fi + 'images/'
            csv_dir = path_to_save_fi + 'csv/'
            _checkpath(img_dir, csv_dir)
            # Save before show(), which may close/clear the current figure.
            plt.savefig(img_dir + title + ".png", bbox_inches='tight')
            _imo_to_file(importances, csv_dir + title + ".csv")
        plt.show()

        if first_only:
            break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment