Created
January 1, 2019 02:20
-
-
Save Alescontrela/f914f12147421ddaa5719304c8a2feee to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DecisionTree(object):
    def __init__(self, x, y, idxs=None, min_leaf=5):
        """
        Build a regression decision-tree node.

        The node picks the (feature, value) split that minimizes the weighted
        standard deviation of the targets on each side, then recursively
        builds child trees until no valid split remains.

        Parameters
        ----------
        x : pandas.DataFrame
            Observations; rows are samples, columns are features
            (accessed via ``x.values`` / ``x.columns`` below).
        y : numpy.ndarray
            Target values, one per row of ``x``.
        idxs : numpy.ndarray, optional
            Row indices this node is responsible for; defaults to all rows.
        min_leaf : int
            Minimum number of samples allowed on each side of a split.
        """
        # ids of data samples to use for the creation of the current decision tree
        if idxs is None: idxs = np.arange(len(y))
        # get observations, labels, and minimum number of samples per leaf
        self.x, self.y, self.idxs, self.min_leaf = x, y, idxs, min_leaf
        self.n, self.c = len(idxs), x.shape[1]  # num. samples and features per observation
        self.val = np.mean(y[idxs])  # prediction value of this node (mean of its targets)
        self.score = float('inf')    # best split score found so far; inf means "leaf"
        self.find_varsplit()

    def find_varsplit(self):
        """
        Determine the feature to perform the current split on, if such a feature
        exists. If so, create two children: one for values <= the split value
        and the other for values > the split value.
        """
        # test all features to find which one yields the lowest weighted std dev
        for i in range(self.c): self.find_better_split(i)
        if self.score == float('inf'): return  # no split was found: leaf node
        x = self.split_col
        lhs = np.nonzero(x <= self.split)[0]
        rhs = np.nonzero(x > self.split)[0]
        # BUGFIX: propagate min_leaf to the children. Previously the children
        # were built with the default min_leaf=5 regardless of what the caller
        # requested, which silently changed the tree's shape.
        self.lhs = DecisionTree(self.x, self.y, self.idxs[lhs], self.min_leaf)
        self.rhs = DecisionTree(self.x, self.y, self.idxs[rhs], self.min_leaf)

    def find_better_split(self, var_idx):
        """
        Determine whether feature ``var_idx`` provides a better split than the
        best found so far. The split score is the sum of each side's standard
        deviation weighted by its sample count; lower is better. Updates
        ``self.var_idx``, ``self.score`` and ``self.split`` on improvement.
        """
        x, y = self.x.values[self.idxs, var_idx], self.y[self.idxs]
        sort_idx = np.argsort(x.T).T  # sort samples by feature value
        sort_y, sort_x = y[sort_idx], x[sort_idx]
        # Running aggregates: everything starts on the right-hand side and is
        # moved left one sample at a time, so each candidate split is O(1).
        rhs_cnt, rhs_sum, rhs_sum2 = self.n, np.sum(sort_y), (sort_y**2).sum()
        lhs_cnt, lhs_sum, lhs_sum2 = 0., 0., 0.
        for i in range(0, self.n - self.min_leaf - 1):
            xi, yi = sort_x[i], sort_y[i]
            lhs_cnt += 1; rhs_cnt -= 1
            lhs_sum += yi; rhs_sum -= yi
            lhs_sum2 += yi**2; rhs_sum2 -= yi**2
            # Skip splits that would leave too few samples on the left, and
            # positions inside a run of equal feature values (can't split there).
            if i < self.min_leaf or xi == sort_x[i+1]: continue
            # Weighted std dev of the targets on each side of the candidate split
            lhs_std = DecisionTree.std_agg(lhs_cnt, lhs_sum, lhs_sum2)
            rhs_std = DecisionTree.std_agg(rhs_cnt, rhs_sum, rhs_sum2)
            curr_score = lhs_std*lhs_cnt + rhs_std*rhs_cnt
            if curr_score < self.score:
                self.var_idx, self.score, self.split = var_idx, curr_score, xi

    @property
    def split_name(self):
        """Column name of the feature this node splits on."""
        return self.x.columns[self.var_idx]

    @property
    def split_col(self):
        """Values of the split feature for the samples owned by this node."""
        return self.x.values[self.idxs, self.var_idx]

    @property
    def is_leaf(self):
        """True when no valid split was found (score was never improved)."""
        return self.score == float('inf')

    def __repr__(self):
        """Summary string of this node; no side effects."""
        # BUGFIX: removed a stray debug `print(self.is_leaf)` — __repr__ must
        # not print.
        s = f'n: {self.n}--val: {self.val}'
        if not self.is_leaf:
            s += f'--score:{self.score}--split: {self.split}--var: {self.split_name}--var_idx: {self.var_idx}'
        return s

    def predict(self, x, debug=False):
        """
        Form predictions for each row of ``x`` by recursing through the tree
        until a leaf is reached; the leaf's mean target value is the prediction.

        ``debug`` (now optional, default False — backward-compatible) prints
        each split decision as it is taken.
        """
        return np.array([self.predict_row(xi, debug) for xi in x])

    def predict_row(self, xi, debug=False):
        """
        Predict the value for a single observation ``xi``. ``debug`` specifies
        whether each split decision should be printed.
        """
        if self.is_leaf: return self.val
        if debug: print(self.split_name, end=" ")
        t = self.lhs if xi[self.var_idx] <= self.split else self.rhs
        if debug:
            if t == self.lhs:
                print("less than", end=" ")
            else:
                print("greater than", end=" ")
        if debug: print(self.split)
        return t.predict_row(xi, debug)

    @staticmethod
    def std_agg(cnt, s1, s2):
        """
        Standard deviation from running aggregates: count ``cnt``, sum ``s1``
        and sum of squares ``s2``, via sqrt(E[y^2] - E[y]^2). ``np.abs`` guards
        against tiny negative variances caused by floating-point round-off.
        """
        return math.sqrt(np.abs((s2/cnt) - (s1/cnt)**2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment