Last active
March 5, 2019 02:23
-
-
Save crawles/ad6796c5ba0b97f6cf5c4baee42f532a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def permutation_importances(est, X_eval, y_eval, metric, features, *, rng=None):
    """Column by column, shuffle values and observe the effect on the eval set.

    For each column in ``features``, the values of that column in ``X_eval``
    are randomly permuted, ``metric`` is re-evaluated, and the drop from the
    unshuffled baseline score is recorded. A larger drop means the model
    relies more on that column.

    Source: http://explained.ai/rf-importance/index.html
    A similar approach can be done during training. See "Drop-column
    importance" in the above article.

    Args:
        est: Model object, passed through unchanged to ``metric``.
        X_eval: Column-indexable evaluation features (presumably a pandas
            DataFrame). Each scored column is temporarily replaced in place
            and always restored before moving on, even if ``metric`` raises.
        y_eval: Evaluation labels, passed through unchanged to ``metric``.
        metric: Callable ``metric(est, X, y)`` returning a score; the
            importance is ``baseline - permuted_score``.
        features: Iterable of column names to score.
        rng: Optional object with a ``permutation`` method (for example
            ``np.random.default_rng(seed)``) used for reproducible shuffles.
            When None, NumPy's global random state is used.

    Returns:
        np.ndarray with one ``baseline - permuted_score`` entry per feature,
        in the order of ``features``.
    """
    permute = np.random.permutation if rng is None else rng.permutation
    baseline = metric(est, X_eval, y_eval)
    imp = []
    for col in features:
        save = X_eval[col].copy()
        # NumPy's permutation returns a plain ndarray, so a DataFrame assigns
        # it by position instead of re-aligning on the index (which would
        # silently undo the shuffle).
        X_eval[col] = permute(X_eval[col])
        try:
            m = metric(est, X_eval, y_eval)
        finally:
            # Restore the original column unconditionally so the caller's
            # data is never left shuffled, even when metric() raises.
            X_eval[col] = save
        imp.append(baseline - m)
    return np.array(imp)


def accuracy_metric(est, X, y):
    """TensorFlow estimator accuracy: ``est.evaluate(...)['accuracy']`` on (X, y).

    Defined at module level (not nested inside permutation_importances) so
    callers can pass it as the ``metric`` argument.
    """
    # make_input_fn is defined elsewhere; shuffle=False / n_epochs=1 presumably
    # request a single, unshuffled pass over the data — confirm against its
    # definition.
    eval_input_fn = make_input_fn(X,
                                  y=y,
                                  shuffle=False,
                                  n_epochs=1)
    return est.evaluate(input_fn=eval_input_fn)['accuracy']
# Score every training column by how much the evaluation-set accuracy falls
# when that column's values are shuffled.
importances = permutation_importances(
    est=est,
    X_eval=dfeval,
    y_eval=y_eval,
    metric=accuracy_metric,
    features=dftrain.columns,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment