Node for a decision tree
# Data wrangling
import pandas as pd
# Array math
import numpy as np
# Quick value count calculator
from collections import Counter

class Node:
    """
    Class for creating the nodes of a decision tree
    """
    def __init__(
        self,
        Y: list,
        X: pd.DataFrame,
        min_samples_split=None,
        max_depth=None,
        depth=None,
        node_type=None,
        rule=None
    ):
        # Saving the data to the node
        self.Y = Y
        self.X = X
        # Saving the hyperparameters; `is not None` keeps explicit zeros intact
        self.min_samples_split = min_samples_split if min_samples_split is not None else 20
        self.max_depth = max_depth if max_depth is not None else 5
        # Default current depth of the node
        self.depth = depth if depth is not None else 0
        # Extracting all the features
        self.features = list(self.X.columns)
        # Type of node
        self.node_type = node_type if node_type is not None else 'root'
        # Rule for splitting
        self.rule = rule if rule is not None else ""
        # Calculating the counts of Y in the node
        self.counts = Counter(Y)
        # Getting the GINI impurity based on the Y distribution
        self.gini_impurity = self.get_GINI()
        # Sorting the counts and saving the final prediction of the node
        counts_sorted = list(sorted(self.counts.items(), key=lambda item: item[1]))
        # Getting the last item
        yhat = None
        if len(counts_sorted) > 0:
            yhat = counts_sorted[-1][0]
        # Saving to object attribute. The node predicts the most frequent class
        self.yhat = yhat
        # Saving the number of observations in the node
        self.n = len(Y)
        # Initiating the left and right nodes as empty nodes
        self.left = None
        self.right = None
        # Default values for splits
        self.best_feature = None
        self.best_value = None

    @staticmethod
    def GINI_impurity(y1_count: int, y2_count: int) -> float:
        """
        Given the observation counts of a binary class, calculates the GINI impurity
        """
        # Ensuring the correct types
        if y1_count is None:
            y1_count = 0
        if y2_count is None:
            y2_count = 0
        # Getting the total observations
        n = y1_count + y2_count
        # If n is 0 then we return the lowest possible gini impurity
        if n == 0:
            return 0.0
        # Getting the probability to see each of the classes
        p1 = y1_count / n
        p2 = y2_count / n
        # Calculating GINI
        gini = 1 - (p1 ** 2 + p2 ** 2)
        # Returning the gini impurity
        return gini
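    # Example: a perfectly mixed node gives GINI_impurity(10, 10) == 0.5,
    # while a pure node gives GINI_impurity(20, 0) == 0.0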

    @staticmethod
    def ma(x: np.ndarray, window: int) -> np.ndarray:
        """
        Calculates the moving average of the given array.
        """
        return np.convolve(x, np.ones(window), 'valid') / window
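    # Example: ma(np.array([1.0, 2.0, 3.0]), 2) returns array([1.5, 2.5]),
    # i.e. the midpoints of adjacent values, used below as candidate split points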

    def get_GINI(self):
        """
        Function to calculate the GINI impurity of a node
        """
        # Getting the 0 and 1 counts
        y1_count, y2_count = self.counts.get(0, 0), self.counts.get(1, 0)
        # Getting the GINI impurity
        return self.GINI_impurity(y1_count, y2_count)

    def best_split(self) -> tuple:
        """
        Given the X features and Y targets, calculates the best split
        for a decision tree
        """
        # Creating a dataset for splitting
        df = self.X.copy()
        df['Y'] = self.Y
        # Getting the GINI impurity for the base input
        GINI_base = self.get_GINI()
        # Finding which split yields the best GINI gain
        max_gain = 0
        # Default best feature and split
        best_feature = None
        best_value = None
        for feature in self.features:
            # Dropping missing values and sorting by the feature value
            Xdf = df.dropna().sort_values(feature)
            # Getting the rolling average of adjacent unique values;
            # these midpoints serve as candidate split points
            xmeans = self.ma(Xdf[feature].unique(), 2)
            for value in xmeans:
                # Splitting the dataset; <= and > mirror the rules used in grow_tree()
                left_counts = Counter(Xdf[Xdf[feature] <= value]['Y'])
                right_counts = Counter(Xdf[Xdf[feature] > value]['Y'])
                # Getting the Y distribution from the dicts
                y0_left, y1_left = left_counts.get(0, 0), left_counts.get(1, 0)
                y0_right, y1_right = right_counts.get(0, 0), right_counts.get(1, 0)
                # Getting the left and right gini impurities
                gini_left = self.GINI_impurity(y0_left, y1_left)
                gini_right = self.GINI_impurity(y0_right, y1_right)
                # Getting the obs count from the left and the right data splits
                n_left = y0_left + y1_left
                n_right = y0_right + y1_right
                # Calculating the weights for each of the nodes
                w_left = n_left / (n_left + n_right)
                w_right = n_right / (n_left + n_right)
                # Calculating the weighted GINI impurity
                wGINI = w_left * gini_left + w_right * gini_right
                # Calculating the GINI gain
                GINIgain = GINI_base - wGINI
                # Checking if this is the best split so far
                if GINIgain > max_gain:
                    best_feature = feature
                    best_value = value
                    # Setting the best gain to the current one
                    max_gain = GINIgain
        return (best_feature, best_value)
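    # Example: for a node with Y = [0, 0, 1, 1] and a single feature 'x'
    # with values [1, 2, 3, 4], best_split() returns ('x', 2.5)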

    def grow_tree(self):
        """
        Recursive method to create the decision tree
        """
        # Making a df from the data
        df = self.X.copy()
        df['Y'] = self.Y
        # If there is GINI to be gained, we split further
        if (self.depth < self.max_depth) and (self.n >= self.min_samples_split):
            # Getting the best split
            best_feature, best_value = self.best_split()
            if best_feature is not None:
                # Saving the best split to the current node
                self.best_feature = best_feature
                self.best_value = best_value
                # Getting the left and right dataframes
                left_df, right_df = df[df[best_feature] <= best_value].copy(), df[df[best_feature] > best_value].copy()
                # Creating the left and right nodes
                left = Node(
                    left_df['Y'].values.tolist(),
                    left_df[self.features],
                    depth=self.depth + 1,
                    max_depth=self.max_depth,
                    min_samples_split=self.min_samples_split,
                    node_type='left_node',
                    rule=f"{best_feature} <= {round(best_value, 3)}"
                )
                self.left = left
                self.left.grow_tree()
                right = Node(
                    right_df['Y'].values.tolist(),
                    right_df[self.features],
                    depth=self.depth + 1,
                    max_depth=self.max_depth,
                    min_samples_split=self.min_samples_split,
                    node_type='right_node',
                    rule=f"{best_feature} > {round(best_value, 3)}"
                )
                self.right = right
                self.right.grow_tree()

    def print_info(self, width=4):
        """
        Method to print the information about the tree
        """
        # Defining the number of spaces
        const = int(self.depth * width ** 1.5)
        spaces = "-" * const
        if self.node_type == 'root':
            print("Root")
        else:
            print(f"|{spaces} Split rule: {self.rule}")
        print(f"{' ' * const} | GINI impurity of the node: {round(self.gini_impurity, 2)}")
        print(f"{' ' * const} | Class distribution in the node: {dict(self.counts)}")
        print(f"{' ' * const} | Predicted class: {self.yhat}")

    def print_tree(self):
        """
        Prints the whole tree from the current node to the bottom
        """
        self.print_info()
        if self.left is not None:
            self.left.print_tree()
        if self.right is not None:
            self.right.print_tree()

    def predict(self, X: pd.DataFrame):
        """
        Batch prediction method
        """
        predictions = []
        for _, x in X.iterrows():
            values = {}
            for feature in self.features:
                values.update({feature: x[feature]})
            predictions.append(self.predict_obs(values))
        return predictions

    def predict_obs(self, values: dict) -> int:
        """
        Method to predict the class given a set of features
        """
        cur_node = self
        while cur_node.depth < cur_node.max_depth:
            # Traversing the nodes all the way to the bottom
            best_feature = cur_node.best_feature
            best_value = cur_node.best_value
            # Stopping at leaves: nodes that were never split further
            if best_feature is None:
                break
            if cur_node.n < cur_node.min_samples_split:
                break
            # Descending into the child; breaking if it is missing avoids an infinite loop
            if values.get(best_feature) <= best_value:
                if cur_node.left is not None:
                    cur_node = cur_node.left
                else:
                    break
            else:
                if cur_node.right is not None:
                    cur_node = cur_node.right
                else:
                    break
        return cur_node.yhat
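
# A minimal usage sketch (an illustration only, not part of the class above):
# it assumes a binary 0/1 target and numeric features, and the toy data
# below is made up purely to demonstrate the API
if __name__ == '__main__':
    # Toy dataset: two numeric features and a binary target
    X = pd.DataFrame({
        'x1': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        'x2': [5.0, 4.0, 6.0, 3.0, 7.0, 2.0, 8.0, 1.0]
    })
    Y = [0, 0, 0, 0, 1, 1, 1, 1]
    # Growing a shallow tree; min_samples_split is lowered for the tiny sample
    root = Node(Y, X, max_depth=2, min_samples_split=2)
    root.grow_tree()
    # Printing the structure and predicting on the training features
    root.print_tree()
    print(root.predict(X))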