Skip to content

Instantly share code, notes, and snippets.

@eldrin
Last active September 20, 2018 20:12
Show Gist options
  • Save eldrin/28c103a11489a32e38d7e60252c76d2f to your computer and use it in GitHub Desktop.
Save eldrin/28c103a11489a32e38d7e60252c76d2f to your computer and use it in GitHub Desktop.
visualize feature importance based on Random Forest
from os.path import join, basename, splitext
import argparse
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
# Set up the command-line interface: three positional file names plus two
# optional tuning knobs for the random forest.
parser = argparse.ArgumentParser()
parser.add_argument("X", help="filename of the feature file (`.npy`) to visualize")
parser.add_argument("y", help="filename of the label file (`.csv` or `.npy`) to visualize classes")
parser.add_argument("out_fn", help="filename for the outputing image (`.pdf`)")
parser.add_argument("--n-estimator", type=int, dest='n_estimator', default=20,
                    help="number of estimator used in random forest")
# n_jobs is forwarded to RandomForestClassifier as a worker count, so it has
# to be parsed as an integer (a float such as 4.0 is not a valid value there).
parser.add_argument("--n-jobs", dest='n_jobs', type=int, default=-1,
                    help="number of parallel jobs used by the random forest "
                         "(-1 uses all available processors)")
args = parser.parse_args()
# Read the feature matrix from the `.npy` file named on the command line.
X = np.load(args.X)

# Read the labels; the loader is chosen from the label file's extension.
ext = splitext(args.y)[1]
if ext not in ('.csv', '.npy'):
    raise NotImplementedError('{} is not supported!'.format(ext))
if ext == '.npy':
    y = np.load(args.y)
else:
    # `.csv` case: one label per line, kept verbatim up to the line break.
    with open(args.y) as label_file:
        y = np.array([row.partition('\n')[0] for row in label_file])

# Features and labels must describe the same number of samples.
if len(y) != X.shape[0]:
    raise ValueError('Feature & label should have same number of samples!')
# Build the random forest and fit it on the loaded features and labels.
rf = RandomForestClassifier(args.n_estimator, n_jobs=args.n_jobs).fit(X, y)

# Per-feature importance scores exposed by the fitted forest.
I = rf.feature_importances_

# Draw the importances as a line plot (feature index on the x axis) and
# write the figure to the requested output file.
fig, ax = plt.subplots()
ax.plot(I)
fig.savefig(args.out_fn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment