Created
November 24, 2016 22:54
-
-
Save dat-boris/bb38683916c2a715d1091e4bfbd40c97 to your computer and use it in GitHub Desktop.
Wrapper class I use for simple decision tree exploration
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Classifier(object):
    """Thin wrapper around a scikit-learn decision tree for quick exploration.

    Constructing an instance trains the whole pipeline on ``df``:
    NaN filling -> imputer -> min-max scaling -> decision tree.  The fitted
    imputer, scaler and tree are kept on the instance so that
    ``evaluate_model`` can apply the same transformations to held-out data.
    """

    def __init__(self,
                 df, depth=3,
                 source_cols=input_columns, target='delta_90',
                 use_classification=False):
        """Fit the preprocessing steps and the decision tree on ``df``.

        :param df: DataFrame holding both the feature columns and ``target``.
        :param depth: ``max_depth`` passed to the decision tree.
        :param source_cols: names of the feature columns to train on.
        :param target: name of the column to predict.
        :param use_classification: when true a ``DecisionTreeClassifier`` is
            trained, otherwise a ``DecisionTreeRegressor``.
        """
        start = time.time()

        self.input_columns = source_cols
        self.target = target
        self.use_classification = use_classification

        # NOTE(review): clean_up_nan() already replaces NaN with 0, so the
        # imputer should have nothing left to fill; it is kept so the fitted
        # pipeline stays the same as before.
        self.imputer = preprocessing.Imputer()
        self.scaler = preprocessing.MinMaxScaler()

        if self.use_classification:
            self.clf = tree.DecisionTreeClassifier(max_depth=depth)
        else:
            self.clf = tree.DecisionTreeRegressor(max_depth=depth)

        input_data, output_target = self.clean_up_nan(df)
        data_trans = self.imputer.fit_transform(input_data)
        data_trans = self.scaler.fit_transform(data_trans)

        # Keep the exact training matrix around for later inspection.
        self.transformed_data = (data_trans, output_target)
        self.clf.fit(data_trans, output_target)

        end = time.time()
        print("Time to train full ML pipline: %0.2f secs" % (end - start))

    def clean_up_nan(self, df):
        """Select the feature/target columns of ``df`` and fill NaN with 0.

        :param df: DataFrame containing ``self.input_columns`` and
            ``self.target``.
        :return: ``(X_input, Y_target)`` - the feature frame and the target
            series, both taken from the NaN-filled copy.
        """
        # Copy into a list so tuples or other sequences of names work too.
        columns = list(self.input_columns)
        test_data = df[columns + [self.target]].fillna(0)
        X_input = test_data[columns]
        Y_target = test_data[self.target]
        print("Number of samples: {} -> {}".format(len(df), len(test_data)))
        return X_input, Y_target

    def show_importance(self):
        """Plot the fitted tree's feature importances as a bar chart.

        Bars are ordered from most to least important.
        """
        feature_importances = pd.Series(self.clf.feature_importances_,
                                        index=self.input_columns)
        # Series.sort() no longer exists in pandas; sort_values() returns the
        # sorted copy instead of sorting in place.
        feature_importances = feature_importances.sort_values(ascending=False)
        ax = feature_importances.plot(kind='bar')
        ax.set(ylabel='Importance (Gini Coefficient)',
               title='Feature importances')

    def print_decision_tree(self, offset_unit=' '):
        """Print the fitted tree as nested ``if``/``else`` pseudo-code.

        Feature names come from ``self.input_columns``; when that is ``None``
        they fall back to ``f<index>``.

        :param offset_unit: string repeated once per depth level to indent
            each nested block.
        """
        tree_ = self.clf.tree_
        feature_names = self.input_columns
        left = tree_.children_left
        right = tree_.children_right
        threshold = tree_.threshold
        value = tree_.value

        # scikit-learn marks a leaf by setting its child pointers to -1.
        # Testing this is safer than comparing the threshold with -2, which
        # could also be a genuine split threshold.
        leaf = -1

        def feature_label(node):
            # Only called for split nodes: leaves carry a negative feature
            # index that would otherwise wrap around (or overflow) the list
            # of column names.
            idx = tree_.feature[node]
            if feature_names is None:
                return 'f%d' % idx
            return str(feature_names[idx])

        def recurse(node, depth=0):
            offset = offset_unit * depth
            if left[node] == leaf:
                print(offset + "return " + str(value[node]))
                return
            print(offset + "if ( " + feature_label(node) + " <= "
                  + str(threshold[node]) + " ) {")
            recurse(left[node], depth + 1)
            print(offset + "} else {")
            recurse(right[node], depth + 1)
            print(offset + "}")

        recurse(0)

    def evaluate_model(self, test_data):
        """Score the trained model on ``test_data`` and print the metrics.

        The data goes through the same NaN filling, imputer and scaler that
        were fitted during training.

        :param test_data: DataFrame containing ``self.input_columns`` and
            ``self.target``.
        :return: accuracy in percent when classifying, otherwise the r2
            score.
        """
        clf = self.clf
        X_input, Y_target = self.clean_up_nan(test_data)

        # Reuse the transformers fitted on the training data (no re-fitting).
        X_test_trans = self.imputer.transform(X_input)
        X_test_trans = self.scaler.transform(X_test_trans)

        Y_pred = clf.predict(X_test_trans)

        if self.use_classification:
            Y_pred_prob = clf.predict_proba(X_test_trans)
            accuracy = metrics.accuracy_score(Y_target, Y_pred) * 100
            print('Accuracy on test set = {:.2f}%'.format(accuracy))
            # Pass the training classes explicitly so the probability
            # columns still line up when the test set lacks some class.
            loss = metrics.log_loss(Y_target, Y_pred_prob,
                                    labels=clf.classes_)
            print('Log-loss = {:.5f}'.format(loss))
        else:
            mse = metrics.mean_squared_error(Y_target, Y_pred)
            accuracy = metrics.r2_score(Y_target, Y_pred)
            # '{:.2f}' (precision), not '{:2f}' (minimum width).
            print('MSE = {:.5f}, r2 = {:.2f}'.format(mse, accuracy))
        return accuracy
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment