This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
class DecisionTree:
    """Decision-tree node holding its depth settings and two child links."""

    def __init__(self, max_depth=6, depth=1):
        """Create a node with no children.

        Args:
            max_depth: Value stored on the node as ``max_depth``.
            depth: Value stored on the node as ``depth``.
        """
        self.max_depth = max_depth
        self.depth = depth
        # Both children start out unset.
        self.left = None
        self.right = None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict(self, data):
    """Return one prediction per row of ``data`` as a NumPy array.

    Each row is handed to ``__flow_data_thru_tree`` as a label-indexed
    Series (the second item of every ``data.iterrows()`` pair).

    Args:
        data: DataFrame whose rows are to be scored.

    Returns:
        ``np.ndarray`` built from the per-row results, in row order.
    """
    predictions = [self.__flow_data_thru_tree(row) for _, row in data.iterrows()]
    return np.array(predictions)
def __flow_data_thru_tree(self, row):
    """Route ``row`` down the tree and return the reached leaf's probability.

    At every non-leaf node the row goes left when its value for the node's
    ``split_feature`` is ``<=`` the node's ``criteria``, otherwise right.

    Args:
        row: Mapping-like object indexable by ``split_feature``.

    Returns:
        The ``probability`` attribute of the leaf node that is reached.
    """
    node = self
    # Iterative descent: same path as recursing into the chosen child.
    while not node.is_leaf_node:
        if row[node.split_feature] <= node.criteria:
            node = node.left
        else:
            node = node.right
    return node.probability
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@property
def is_leaf_node(self):
    """Whether this node is a leaf, i.e. its ``left`` child is ``None``."""
    return self.left is None
@property
def probability(self):
    """Class proportions of the target column among this node's rows.

    Returns:
        A list of floats, one per class present in ``self.data[self.target]``,
        ordered by ascending class label. Each entry is that class's row
        count divided by ``len(self.data)``. Classes with no rows in this
        node are not represented.
    """
    # value_counts() orders by frequency, which would put the majority class
    # first whatever its label; sorting by label gives a stable position.
    class_counts = self.data[self.target].value_counts().sort_index()
    return (class_counts / len(self.data)).tolist()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __init__(self, max_depth=4, depth=1):
    """Create a node with no children.

    Args:
        max_depth: Value stored on the node as ``max_depth``.
        depth: Value stored on the node as ``depth``.
    """
    self.max_depth = max_depth
    self.depth = depth
    # Both children start out unset.
    self.left = None
    self.right = None
def __create_branches(self): | |
self.left = DecisionTree(max_depth = self.max_depth, | |
depth = self.depth + 1) | |
self.right = DecisionTree(max_depth = self.max_depth, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __create_branches(self):
    """Split this node's rows on the chosen feature and fit both children.

    Rows whose ``split_feature`` value is ``<= criteria`` go to ``self.left``;
    rows whose value is ``> criteria`` go to ``self.right``. A row for which
    both comparisons are false (e.g. a NaN value) goes to neither child.

    Each child is built with this node's ``max_depth`` and ``depth + 1`` so
    the depth settings carry down the tree instead of resetting to defaults.
    """
    # type(self) keeps children the same class as the parent.
    node_cls = type(self)
    child_depth = self.depth + 1
    self.left = node_cls(max_depth=self.max_depth, depth=child_depth)
    self.right = node_cls(max_depth=self.max_depth, depth=child_depth)

    feature_values = self.data[self.split_feature]
    left_rows = self.data[feature_values <= self.criteria]
    right_rows = self.data[feature_values > self.criteria]

    self.left.fit(data=left_rows, target=self.target)
    self.right.fit(data=right_rows, target=self.target)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __find_best_split(self):
    """Pick the feature column and split value with the highest gain.

    Every column in ``self.independent`` is scored with
    ``__find_best_split_for_column``; columns for which that helper reports
    no split (``split is None``) are skipped. On equal gain the earlier
    column is kept.

    Returns:
        ``(split, col)`` for the best candidate, or ``(None, None)`` when no
        column yields a split.
    """
    best_gain = None
    best_value = None
    best_col = None
    for col in self.independent:
        gain, value = self.__find_best_split_for_column(col)
        if value is None:
            continue
        # First usable candidate wins outright; later ones must strictly beat it.
        if best_value is None or best_gain < gain:
            best_gain, best_value, best_col = gain, value, col
    return best_value, best_col
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __find_best_split_for_column(self, col): | |
x = self.data[col] | |
unique_values = x.unique() | |
if len(unique_values) == 1: return None, None | |
information_gain = None | |
split = None | |
for val in unique_values: | |
left = x <= val | |
right = x > val | |
left_data = self.data[left] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __calculate_impurity_score(self, data):
    """Return the Gini impurity of the labels in ``data``.

    The score is the sum of ``p * (1 - p)`` over the distinct values, where
    ``p`` is a value's count divided by ``len(data)``. With exactly two
    classes this equals ``2 * p * (1 - p)``; a single class scores 0 rather
    than failing, and more than two classes are supported.

    Args:
        data: pandas Series of labels, or ``None``.

    Returns:
        ``0`` when ``data`` is ``None`` or empty, otherwise the impurity as a
        float.
    """
    if data is None or data.empty:
        return 0
    proportions = data.value_counts() / len(data)
    return float((proportions * (1 - proportions)).sum())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __init__(self):
    """Create a node with both child links unset."""
    self.left = None
    self.right = None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DecisionTree:
    """Decision-tree node that records its training data and scores rows."""

    def fit(self, data, target):
        """Store the training frame and derive the feature column names.

        Args:
            data: DataFrame holding the feature columns and the target column.
            target: Name of the column in ``data`` to predict.

        Raises:
            ValueError: If ``target`` is not among ``data``'s columns.
        """
        self.data = data
        self.target = target
        # Every column except the target is a candidate feature.
        self.independent = self.data.columns.tolist()
        self.independent.remove(target)

    def predict(self, data):
        """Return one prediction per row of ``data`` as a NumPy array.

        Rows are passed to ``__flow_data_thru_tree`` as label-indexed Series
        from ``data.iterrows()`` so they can be indexed by column name; rows
        taken from ``data.values`` only support positional indexing.
        ``__flow_data_thru_tree`` is not defined in this class body and must
        be available on the instance.

        Args:
            data: DataFrame whose rows are to be scored.
        """
        return np.array(
            [self.__flow_data_thru_tree(row) for _, row in data.iterrows()]
        )
Newer | Older