Last active
September 20, 2018 20:12
-
-
Save eldrin/28c103a11489a32e38d7e60252c76d2f to your computer and use it in GitHub Desktop.
Visualize feature importance based on Random Forest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
from os.path import join, basename, splitext

import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier
def build_parser():
    """Build the command-line parser for the feature-importance script.

    Returns:
        argparse.ArgumentParser: parser exposing the positionals ``X``,
        ``y`` and ``out_fn`` plus the options ``--n-estimator`` and
        ``--n-jobs``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("X", help="filename of the feature file (`.npy`) to visualize")
    parser.add_argument("y", help="filename of the label file (`.csv` or `.npy`) to visualize classes")
    parser.add_argument("out_fn", help="filename for the outputing image (`.pdf`)")
    parser.add_argument("--n-estimator", type=int, dest='n_estimator', default=20,
                        help="number of estimator used in random forest")
    # n_jobs is a count of parallel workers, so it must be parsed as an
    # integer; parsing it as float handed e.g. 4.0 to the classifier.
    parser.add_argument("--n-jobs", dest='n_jobs', type=int, default=-1,
                        help="number of parallel jobs used by the random forest "
                             "(-1 uses all processors)")
    return parser


def load_labels(path):
    """Load the label vector from a `.csv` or `.npy` file.

    A `.csv` file is read line by line; each whole line (without its
    trailing newline) becomes one label. A `.npy` file is loaded with
    ``np.load``.

    Args:
        path: filename of the label file.

    Returns:
        numpy.ndarray: the labels, one entry per sample.

    Raises:
        NotImplementedError: if the file extension is neither `.csv`
            nor `.npy`.
    """
    ext = splitext(path)[1]
    if ext == '.csv':
        with open(path) as f:
            return np.array([line.rstrip('\n') for line in f])
    if ext == '.npy':
        return np.load(path)
    raise NotImplementedError('{} is not supported!'.format(ext))


def main(argv=None):
    """Fit a random forest and plot its feature importances to a file.

    Args:
        argv: optional list of command-line arguments; when ``None``,
            argparse falls back to ``sys.argv[1:]``.

    Raises:
        NotImplementedError: if the label file extension is unsupported.
        ValueError: if features and labels differ in number of samples.
    """
    args = build_parser().parse_args(argv)

    # load the feature file
    X = np.load(args.X)

    # load the label file
    y = load_labels(args.y)

    # check shape
    if X.shape[0] != len(y):
        raise ValueError('Feature & label should have same number of samples!')

    # initiate model
    rf = RandomForestClassifier(n_estimators=args.n_estimator, n_jobs=args.n_jobs)
    rf.fit(X, y)

    # get the importance
    importances = rf.feature_importances_

    # save the picture, then release the figure so it does not linger
    fig = plt.figure()
    plt.plot(importances)
    plt.savefig(args.out_fn)
    plt.close(fig)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment