Created
December 27, 2018 01:12
-
-
Save amaarora/0b1b9c9db4ba3b5bd675ec4b746f1072 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TreeEnsemble():
    """Ensemble of ``DecisionTree`` regressors, each fit on a random row subsample.

    Parameters
    ----------
    x : object supporting ``.iloc[...]`` row selection (e.g. a pandas DataFrame)
        Feature table.
    y : sequence indexable by an integer array (e.g. a numpy array)
        Target values, one per row of ``x``.
    n_trees : int
        Number of trees to build.
    sample_sz : int
        Number of rows drawn (without replacement) for each tree.
    min_leaf : int, default 5
        Forwarded to every ``DecisionTree``.
    """

    def __init__(self, x, y, n_trees, sample_sz, min_leaf=5):
        # Seeds numpy's *global* RNG so repeated constructions draw the same samples.
        np.random.seed(42)
        self.x, self.y, self.sample_sz, self.min_leaf = x, y, sample_sz, min_leaf
        self.trees = [self.create_tree() for _ in range(n_trees)]

    def create_tree(self):
        """Draw a random subsample of rows and fit one ``DecisionTree`` on it."""
        idxs = np.random.permutation(len(self.y))[:self.sample_sz]
        # The slice can be shorter than ``sample_sz`` when fewer rows exist, so
        # the positional indices handed to the tree are derived from the rows
        # actually drawn rather than from ``sample_sz`` itself.
        return DecisionTree(self.x.iloc[idxs], self.y[idxs],
                            idxs=np.arange(len(idxs)), min_leaf=self.min_leaf)

    def predict(self, x):
        """Return the element-wise mean of every tree's predictions for ``x``."""
        return np.mean([t.predict(x) for t in self.trees], axis=0)
def std_agg(cnt, s1, s2):
    """Return the population standard deviation from running aggregates.

    Parameters
    ----------
    cnt : number of values aggregated (must be non-zero).
    s1 : sum of the values.
    s2 : sum of the squared values.

    The variance is computed as ``s2/cnt - (s1/cnt)**2``. Floating-point
    rounding can push that difference slightly below zero when the values are
    (nearly) constant, which would make ``math.sqrt`` raise ``ValueError``;
    the variance is therefore clamped at zero before taking the root.
    """
    variance = (s2 / cnt) - (s1 / cnt) ** 2
    return math.sqrt(max(variance, 0.0))
class DecisionTree():
    """Binary regression tree node that recursively splits its rows.

    Each node stores the mean target of its rows (``val``) and, unless it is a
    leaf, the best split found (``var_idx``, ``split``, ``score``) together
    with its ``lhs``/``rhs`` child nodes.

    Parameters
    ----------
    x : object exposing ``.values`` (2-D), ``.shape`` and ``.columns``
        Feature table (e.g. a pandas DataFrame).
    y : array indexable by an integer array
        Target values aligned with the rows of ``x``.
    idxs : integer array
        Row positions (into ``x.values`` / ``y``) that belong to this node.
    min_leaf : int, default 5
        Minimum number of rows allowed on each side of a split.
    """

    def __init__(self, x, y, idxs, min_leaf=5):
        self.x, self.y, self.idxs, self.min_leaf = x, y, idxs, min_leaf
        self.n, self.c = len(idxs), x.shape[1]
        self.val = np.mean(y[idxs])
        # ``inf`` means "no valid split found"; see ``is_leaf``.
        self.score = float('inf')
        self.find_varsplit()

    def find_varsplit(self):
        """Try every column, then grow child nodes around the best split."""
        for i in range(self.c):
            self.find_better_split(i)
        if self.is_leaf:
            return
        x = self.split_col
        lhs = np.nonzero(x <= self.split)[0]
        rhs = np.nonzero(x > self.split)[0]
        # Forward min_leaf so the whole tree honours the caller's setting
        # instead of silently reverting to the default in child nodes.
        self.lhs = DecisionTree(self.x, self.y, self.idxs[lhs],
                                min_leaf=self.min_leaf)
        self.rhs = DecisionTree(self.x, self.y, self.idxs[rhs],
                                min_leaf=self.min_leaf)

    def find_better_split(self, var_idx):
        """Scan column ``var_idx`` and record its best split if it beats ``self.score``.

        Rows are visited in ascending feature order while running count / sum /
        sum-of-squares are moved from the right-hand side to the left-hand
        side, so each candidate split is scored without re-reading the data.
        """
        x, y = self.x.values[self.idxs, var_idx], self.y[self.idxs]
        sort_idx = np.argsort(x)
        sort_y, sort_x = y[sort_idx], x[sort_idx]
        rhs_cnt, rhs_sum, rhs_sum2 = self.n, sort_y.sum(), (sort_y ** 2).sum()
        lhs_cnt, lhs_sum, lhs_sum2 = 0, 0., 0.

        # Always leave at least one row on the right so ``sort_x[i + 1]`` exists
        # (only matters when min_leaf < 1; otherwise identical to n - min_leaf).
        for i in range(0, self.n - max(self.min_leaf, 1)):
            xi, yi = sort_x[i], sort_y[i]
            lhs_cnt += 1
            rhs_cnt -= 1
            lhs_sum += yi
            rhs_sum -= yi
            lhs_sum2 += yi ** 2
            rhs_sum2 -= yi ** 2
            # Skip until the left side holds min_leaf rows, and never split
            # between two rows that share the same feature value.
            if i < self.min_leaf - 1 or xi == sort_x[i + 1]:
                continue

            lhs_std = std_agg(lhs_cnt, lhs_sum, lhs_sum2)
            rhs_std = std_agg(rhs_cnt, rhs_sum, rhs_sum2)
            # Score = size-weighted sum of the two sides' standard deviations.
            curr_score = lhs_std * lhs_cnt + rhs_std * rhs_cnt
            if curr_score < self.score:
                self.var_idx, self.score, self.split = var_idx, curr_score, xi

    @property
    def split_name(self):
        """Column label of the feature this node splits on."""
        return self.x.columns[self.var_idx]

    @property
    def split_col(self):
        """Values of the split feature for this node's rows."""
        return self.x.values[self.idxs, self.var_idx]

    @property
    def is_leaf(self):
        """True when no valid split was found for this node."""
        return self.score == float('inf')

    def __repr__(self):
        s = f'n: {self.n}; val:{self.val}'
        if not self.is_leaf:
            s += f'; score:{self.score}; split:{self.split}; var:{self.split_name}'
        return s

    def predict(self, x):
        """Return an array with one prediction per row of ``x``.

        ``x`` may be a 2-D array (or any iterable of rows) or a pandas
        DataFrame/Series; pandas input is unwrapped via ``.values`` because
        iterating a DataFrame yields column labels rather than rows.
        """
        if hasattr(x, 'iloc'):
            x = x.values
        return np.array([self.predict_row(xi) for xi in x])

    def predict_row(self, xi):
        """Walk from this node down to a leaf for a single row and return its value."""
        if self.is_leaf:
            return self.val
        t = self.lhs if xi[self.var_idx] <= self.split else self.rhs
        return t.predict_row(xi)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment