Last active
December 3, 2018 20:47
-
-
Save kweimann/6672afc78b8f05cba75bd5b05fa783e0 to your computer and use it in GitHub Desktop.
Simple decision tree in Python for spam classification on spambase dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import logging | |
class DTNode:
    """Internal decision-tree node that routes a sample on one feature.

    A sample whose value at ``feature`` is strictly below ``threshold`` is
    delegated to ``left``; every other sample is delegated to ``right``.
    Both children start as ``None`` and must be assigned before ``predict``
    is called.
    """

    def __init__(self, feature, threshold):
        """Create a node without children.

        Args:
            feature: index used to look up the tested value in a sample.
            threshold: value the sample's feature is compared against.
        """
        self.feature = feature
        self.threshold = threshold
        self.left = None
        self.right = None

    def predict(self, x):
        """Return the prediction of the child subtree that ``x`` falls into.

        Args:
            x: a single sample, indexable by ``self.feature``.

        Raises:
            ValueError: if either child has not been assigned yet.
        """
        if not self._is_initialized:
            raise ValueError('node is not initialized')
        if x[self.feature] < self.threshold:
            return self.left.predict(x)
        return self.right.predict(x)

    @property
    def _is_initialized(self):
        """Whether both children have been assigned."""
        # Compare against None explicitly: relying on truthiness would return
        # a child object instead of a bool and would misreport any child
        # whose type happens to evaluate as falsy.
        return self.left is not None and self.right is not None
class DTLeaf:
    """Terminal decision-tree node holding one fixed prediction."""

    def __init__(self, y):
        """Store ``y`` as the value this leaf always predicts."""
        self.y = y

    def predict(self, _):
        """Return the stored prediction; the sample argument is ignored."""
        return self.y
class DTForest:
    """Ensemble of decision trees that predicts by majority vote."""

    def __init__(self, decision_trees):
        """Keep the given collection of trees; each must expose ``predict``."""
        self.decision_trees = decision_trees

    def predict(self, x):
        """Return the label predicted by the most trees for sample ``x``.

        The votes are counted with ``np.bincount``, so every tree must
        predict a non-negative integer label. ``np.argmax`` picks the first
        maximum, so a tie goes to the smallest label.
        """
        votes = [tree.predict(x) for tree in self.decision_trees]
        votes_per_label = np.bincount(votes)
        return np.argmax(votes_per_label)
def entropy(X):
    """Return the Shannon entropy, in bits, of the labels in ``X``.

    Args:
        X: 1-D array of non-negative integer (or boolean) labels, as
            required by ``np.bincount``.

    Returns:
        The entropy of the empirical label distribution. Labels that never
        occur are dropped before the logarithm is taken.
    """
    label_freq = np.bincount(X) / len(X)
    observed = label_freq[label_freq > 0]
    weighted_log = observed * np.log2(observed)
    return -np.sum(weighted_log)
def buildDT(X, y, n_features_sampled=None):
    """Recursively build a binary decision tree that maximizes information gain.

    At each node every candidate feature is split at its mean value; the
    split with the largest positive information gain becomes a ``DTNode``
    and both halves are built recursively. When no candidate split yields a
    positive gain, a ``DTLeaf`` holding the most frequent label is returned.

    Args:
        X: 2-D array of feature values, one row per sample.
        y: 1-D array of non-negative integer (or boolean) labels, one per
            row of ``X`` (a requirement of ``np.bincount``).
        n_features_sampled: if truthy, the number of features drawn at
            random without replacement at each node, clamped to
            ``[1, n_features]``; otherwise all features are considered.

    Returns:
        The root of the built tree: a ``DTNode`` or a ``DTLeaf``.

    Raises:
        ValueError: if ``X`` contains no samples.
    """
    n_samples, n_features = X.shape
    if n_samples == 0:
        # Fail early with a clear message instead of dividing by zero below.
        raise ValueError('cannot build a decision tree from an empty dataset')

    H_before_split = entropy(y)

    if n_features_sampled:
        features = np.random.choice(n_features,
                                    min(n_features, max(1, n_features_sampled)),
                                    replace=False)
    else:
        features = np.arange(n_features)

    best_information_gain, node_data = 0, None
    for feature in features:
        X_feature = X[:, feature]
        threshold = np.mean(X_feature)
        left_idx = X_feature < threshold
        # Negate the left mask rather than testing ``>=`` so the two sides
        # always partition the samples. With NaN values both comparisons are
        # False, which would drop samples and make the split look perfect.
        right_idx = np.logical_not(left_idx)
        y_left, y_right = y[left_idx], y[right_idx]
        p_y_left = len(y_left) / n_samples
        p_y_right = len(y_right) / n_samples
        H_after_split = p_y_left * entropy(y_left) + p_y_right * entropy(y_right)
        information_gain = H_before_split - H_after_split
        if information_gain > best_information_gain:
            best_information_gain = information_gain
            node_data = feature, threshold, left_idx, y_left, right_idx, y_right

    if node_data is None:
        # No split improved on the current entropy: predict the majority label.
        most_frequent_y = np.argmax(np.bincount(y))
        return DTLeaf(most_frequent_y)

    feature, threshold, left_idx, y_left, right_idx, y_right = node_data
    node = DTNode(feature, threshold)
    node.left = buildDT(X[left_idx], y_left, n_features_sampled)
    node.right = buildDT(X[right_idx], y_right, n_features_sampled)
    return node
def unison_shuffle(a, b):
    """Shuffle two equally long arrays with one shared random permutation.

    Args:
        a: first array.
        b: second array, of the same length as ``a``.

    Returns:
        A tuple ``(a_shuffled, b_shuffled)`` in which elements that shared
        an index before the shuffle still share one afterwards.

    Raises:
        ValueError: if the two lengths differ.
    """
    n_items = len(a)
    if n_items != len(b):
        raise ValueError('array lengths do not match')
    order = np.random.permutation(n_items)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b
def accuracy_score(y_true, y_pred):
    """Return the fraction of predictions that equal the true labels.

    Args:
        y_true: array of ground-truth labels.
        y_pred: array of predicted labels with the same shape as ``y_true``.

    Returns:
        The number of element-wise matches divided by ``len(y_true)``.

    Raises:
        ValueError: if the two shapes differ.
    """
    if y_true.shape == y_pred.shape:
        n_matches = np.equal(y_true, y_pred).sum()
        return n_matches / len(y_true)
    raise ValueError('array shapes do not match')
if __name__ == '__main__':
    # Fixed seed so the shuffle, the bootstrap samples and the per-node
    # feature subsets are reproducible between runs.
    np.random.seed(12345)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)-8s %(name)-25.25s %(message)s')

    # spambase dataset: https://archive.ics.uci.edu/ml/datasets/spambase
    data = np.array(pd.read_csv('spambase.data', header=None))
    # The last column is the label; all preceding columns are features.
    X, y = data[:, :-1], data[:, -1].astype(np.bool_)
    X, y = unison_shuffle(X, y)

    # First half of the shuffled data is for training, the rest for validation.
    split = len(X) // 2
    X_train, y_train = X[:split], y[:split]
    X_val, y_val = X[split:], y[split:]

    _, n_features = X.shape
    n_features_sampled = int(np.sqrt(n_features))
    forest_size = 20

    decision_trees = []
    for tree_idx in range(forest_size):
        # Bootstrap: draw as many training rows as there are, with replacement.
        sampled_idx = np.random.randint(0, high=split, size=split)
        X_bootstrap, y_bootstrap = X_train[sampled_idx], y_train[sampled_idx]
        decision_tree = buildDT(X_bootstrap, y_bootstrap,
                                n_features_sampled=n_features_sampled)
        decision_trees.append(decision_tree)
        # Pass the argument separately so logging formats the message lazily.
        logging.info('finished building tree no. %d', tree_idx)

    decision_forest = DTForest(decision_trees)

    y_val_pred = np.array([decision_forest.predict(sample) for sample in X_val])
    val_accuracy = 100 * accuracy_score(y_val, y_val_pred)
    print('val accuracy: %.2f%%' % val_accuracy)

    y_train_pred = np.array([decision_forest.predict(sample) for sample in X_train])
    train_accuracy = 100 * accuracy_score(y_train, y_train_pred)
    print('train accuracy: %.2f%%' % train_accuracy)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment