Created November 20, 2015 20:52
Save souldeux/99f71087c712c48e50b7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def determine_feature_importance(df, fi_threshold=1, feature_start=4,
                                 target_col=0, n_estimators=10000,
                                 random_state=None, show_plot=True):
    """Rank the feature columns of ``df`` with a random forest and plot them.

    A ``RandomForestClassifier`` is fitted on the feature columns (every
    column from ``feature_start`` onward) against the column ``target_col``.
    Importances are rescaled to a percentage of the largest importance. The
    features strictly above ``fi_threshold`` percent are printed in
    descending order and, optionally, shown as a horizontal bar chart.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame. The defaults reproduce the original layout: the score is
        column 0, and the score/ID fields occupy columns 0-3, so feature
        values start at column 4.
    fi_threshold : float, optional
        Keep features whose relative importance is strictly greater than this
        percentage of the maximum importance (default 1).
    feature_start : int, optional
        Index of the first feature column (default 4).
    target_col : int, optional
        Index of the column used as the classification target (default 0).
    n_estimators : int, optional
        Number of trees in the forest (default 10000).
    random_state : int or None, optional
        Seed forwarded to the forest. Pass an int for reproducible
        importances; the default ``None`` matches the original unseeded run.
    show_plot : bool, optional
        Whether to draw and show the bar chart (default True).

    Returns
    -------
    str
        The status message ``"Feature importance determined"``.
    """
    values = df.values
    # Feature names and values are every column from feature_start on, which
    # excludes the leading score & ID fields.
    features_list = df.columns.values[feature_start:]
    print("Features List: \n{}".format(features_list))
    X = values[:, feature_start:]
    y = values[:, target_col]

    print("\nCreating Random Forest Classifier...\n")
    forest = RandomForestClassifier(oob_score=True,
                                    n_estimators=n_estimators,
                                    random_state=random_state)
    print("\nFitting Random Forest Classifier...\n")
    forest.fit(X, y)
    # oob_score=True is requested, so report it. Importances from a forest
    # that cannot predict the target are not meaningful.
    print("\nOut-of-bag score: {}".format(forest.oob_score_))
    feature_importance = forest.feature_importances_
    print(feature_importance)

    max_importance = feature_importance.max()
    print("\nMaximum feature importance is currently: {}".format(max_importance))
    if max_importance <= 0:
        # All importances are zero (e.g. a constant target). Normalizing
        # would divide by zero and produce NaNs, so stop here.
        print("\nNo feature carries any importance; nothing to rank.")
        return "Feature importance determined"

    # Express every importance as a percentage of the most important feature.
    feature_importance = 100.0 * (feature_importance / max_importance)
    print("\nNormalized feature importance: \n{}".format(feature_importance))
    print("\nNormalized maximum feature importance: \n{}".format(
        feature_importance.max()))
    print("\nUsing fi_threshold == {}".format(fi_threshold))

    # Indices (into features_list) of the features strictly above the threshold.
    important_idx = np.where(feature_importance > fi_threshold)[0]
    print("\nRetrieved important_idx: {}".format(important_idx))
    important_features = features_list[important_idx]
    print("\n{} Important features(> {} % of max importance):\n{}".format(
        important_features.shape[0], fi_threshold, important_features))

    # Positions within important_idx, ordered from most to least important.
    sorted_idx = np.argsort(feature_importance[important_idx])[::-1]
    print("\nFeatures sorted by importance (DESC):\n{}".format(
        important_features[sorted_idx]))

    if show_plot and sorted_idx.size:
        # barh draws bottom-up, so use ascending order to put the most
        # important feature at the top of the chart.
        ascending = sorted_idx[::-1]
        pos = np.arange(sorted_idx.shape[0]) + .5
        # Right-hand panel of a 1x2 grid, kept from the original layout.
        # NOTE(review): the left panel is never drawn here; confirm whether a
        # caller fills subplot(1, 2, 1) before calling this function.
        plt.subplot(1, 2, 2)
        plt.barh(pos, feature_importance[important_idx][ascending],
                 align='center')
        plt.yticks(pos, important_features[ascending])
        plt.xlabel('Relative importance')
        plt.ylabel('Variable importance')
        plt.draw()
        plt.show()
    return "Feature importance determined"
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.