Skip to content

Instantly share code, notes, and snippets.

@sakethramanujam
Last active December 10, 2019 05:31
Show Gist options
  • Save sakethramanujam/dd04dcefcd900af04387c6498848a127 to your computer and use it in GitHub Desktop.
Save sakethramanujam/dd04dcefcd900af04387c6498848a127 to your computer and use it in GitHub Desktop.
Saving Feature Importances
# Root output directory for feature-importance artefacts; callers may
# reassign it before calling save_importances().
path_to_save_fi = "./test/"
def _checkpath(*paths):
    """Ensure every given directory exists, creating parent directories as needed.

    Each argument is converted with ``str`` before use. Directories that
    already exist are left untouched.

    Raises:
        FileExistsError: if a path exists but is not a directory.
        OSError: if a directory cannot be created (e.g. permissions).
    """
    for path in paths:
        # exist_ok=True replaces the separate exists()/makedirs() pair, which
        # could race with another process creating the same directory.
        os.makedirs(str(path), exist_ok=True)
def get_fe_lable(indices, columns=None):
    """Map positional feature indices to their column names.

    Args:
        indices: Iterable of integer positions into ``columns``.
        columns: Indexable sequence of feature names. Defaults to
            ``train_data.columns`` (the module-level training frame), which
            preserves the original behaviour for existing callers.

    Returns:
        List of names, in the same order as ``indices``.
    """
    if columns is None:
        columns = train_data.columns
    return [columns[index] for index in indices]
def _imo_to_file(importances, labels, filename):
    """Write a two-column CSV ranking features by importance score.

    ``importances`` is sorted in descending order and paired positionally
    with ``labels``, so ``labels`` should already be ordered from most to
    least important. Pairing stops at the shorter of the two inputs. The CSV
    is written without the DataFrame index.
    """
    descending = np.sort(importances)[::-1]
    rows = [(name, score) for name, score in zip(labels, descending)]
    table = pd.DataFrame(rows, columns=['Feature Name', 'Importance Score'])
    table.to_csv(filename, index=False)
def save_importances(model, targets, path=None):
    """Plot, save and export the feature importances of every sub-estimator.

    For each estimator in ``model.estimators_``, this function:

    * draws a stem plot of its ``feature_importances_`` in descending order,
      annotated with the scores rounded to 2 decimals;
    * saves the plot as ``<root>/images/<target>.png`` and then shows it;
    * writes the same ranking to ``<root>/csv/<target>.csv``.

    Args:
        model: Fitted model exposing ``estimators_``; each estimator must
            expose ``feature_importances_``.
        targets: Target names, aligned by position with
            ``model.estimators_`` (e.g. ``y_test.columns``).
        path: Output root directory. Defaults to the module-level
            ``path_to_save_fi``, looked up at call time, so reassigning that
            global before calling still works.
    """
    import os  # local import: output paths are built with os.path.join

    root = path_to_save_fi if path is None else path
    # os.path.join works with or without a trailing slash and with Path roots;
    # plain string concatenation required a trailing "/" on the root.
    img_dir = os.path.join(root, "images")
    csv_dir = os.path.join(root, "csv")
    _checkpath(img_dir, csv_dir)

    for position, estimator in enumerate(model.estimators_):
        importances = estimator.feature_importances_
        # Feature positions ordered from most to least important.
        indices = np.argsort(importances)[::-1]
        # str() so non-string target names can still be used in file names.
        title = str(targets[position])
        labels = get_fe_lable(indices)

        # Use a fresh figure per estimator. Reusing plt.gcf() stacks every
        # estimator's stems and annotations onto the previous ones whenever
        # plt.show() leaves the figure open (e.g. non-interactive backends).
        fig = plt.figure(figsize=(16, 10))
        # use_line_collection is deliberately not passed. It was deprecated in
        # Matplotlib 3.6 (always on) and removed in 3.8, where it raises
        # TypeError.
        plt.stem(importances[indices])
        plt.ylim(0, max(importances) + 0.005)
        plt.title(f'Feature Importance Plot for prediction of {title}')
        plt.xlabel('Feature Name')
        plt.ylabel('Feature Importance Score')
        # One tick per plotted feature, instead of relying on the unrelated
        # global train_data.shape[1].
        plt.xticks(range(len(indices)), labels, rotation='vertical', fontsize=12)
        for x_pos, index in enumerate(indices):
            plt.annotate(round(importances[index], 2),
                         (x_pos - 0.05, importances[index] + 0.001),
                         fontsize=12)

        plt.savefig(os.path.join(img_dir, title + ".png"), bbox_inches='tight')
        plt.show()
        plt.close(fig)  # release the figure now that it is saved and shown
        _imo_to_file(importances, labels, os.path.join(csv_dir, title + ".csv"))

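The code above can be used as follows. First set the output root directory, then call `save_importances` with a fitted model that exposes `estimators_` and the target names in the same order: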

# Example usage. `model` and `y_test` are not defined in this snippet; the
# caller must provide a fitted model exposing `estimators_` and target names
# aligned with it.
path_to_save_fi = 'your/path/'
save_importances(model,y_test.columns)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment