Skip to content

Instantly share code, notes, and snippets.

Created December 26, 2012 18:35
Show Gist options
  • Save anonymous/4382106 to your computer and use it in GitHub Desktop.
Save anonymous/4382106 to your computer and use it in GitHub Desktop.
Decision Tree: A CART Implementation
import sys
from math import sqrt
# Starting value for the minimum search in prune(); any loss increase at or
# above it is never selected.
INFINITY = 100000000
# A node holding fewer training rows than this is not split any further.
MIN_SUPPORT = 5
# A node deeper than this is not split any further (the root is at depth 0).
MAX_DEPTH = 10
# Per-leaf penalty in evaluate() and the threshold below which prune()
# collapses a split.
ALPHA = 20
# Fraction of the data rows, taken from the start, that is used for training.
TRAINING_RATIO = 0.9
# Column index, within each data row, of the value to predict.
TARGET_ID = 13
def print_info(depth, string):
    """Print ``string`` as a "- " bullet indented by ``depth`` spaces.

    depth: number of leading spaces.
    string: text to show after the bullet.

    The single-argument call form of ``print`` is used because it behaves
    the same on Python 2 and Python 3, whereas the ``print`` statement is a
    syntax error on Python 3.
    """
    print(" " * depth + "- " + string)
def predict(tree, instances):
    """Return ``tree.predict(instance)`` for each instance, as a list in input order."""
    predictions = []
    for sample in instances:
        predictions.append(tree.predict(sample))
    return predictions
def loss(train, target_id):
    """Return the sum of squared deviations of the target column from its mean.

    train: sequence of rows, each an indexable sequence of numbers.
    target_id: index of the target column within each row.

    An empty ``train`` has nothing to deviate, so 0.0 is returned rather
    than dividing by zero while computing the mean.
    """
    if len(train) == 0:
        return 0.0
    targets = [row[target_id] for row in train]
    # float() keeps the mean exact for integer targets on Python 2.
    avg = sum(targets) / float(len(targets))
    ret = 0.0
    for value in targets:
        ret += (value - avg) ** 2
    return ret
def evaluate(pred, exact, target_id, num_leaf, alpha = ALPHA):
    """Return the squared prediction error plus a per-leaf penalty.

    pred: predicted values, one per row of ``exact``.
    exact: rows whose ``target_id`` column holds the true value.
    num_leaf: leaf count that is multiplied by ``alpha`` and added to the
        error; pass ``alpha=0`` to get the plain sum of squared errors.
    """
    total = 0.0
    for idx, row in enumerate(exact):
        residual = row[target_id] - pred[idx]
        total += residual ** 2
    penalty = alpha * num_leaf
    return total + penalty
def find_internal_nodes(tree, internal_nodes):
    """Collect the nodes under ``tree`` whose two children are both leaves.

    tree: node exposing ``is_leaf`` and, when not a leaf, ``left``/``right``.
    internal_nodes: list that matching nodes are appended to, parent-first
        and left before right. Nothing is returned.

    A node flagged as a leaf is neither reported nor descended into. This
    matters for two cases: a root that never received children, and a node
    that was collapsed by pruning but still carries its old ``left`` and
    ``right`` (reporting it again would let the caller prune it forever).
    """
    if tree.is_leaf:
        return
    left, right = tree.left, tree.right
    if left.is_leaf and right.is_leaf:
        internal_nodes.append(tree)
        return
    if not left.is_leaf:
        find_internal_nodes(left, internal_nodes)
    if not right.is_leaf:
        find_internal_nodes(right, internal_nodes)
def prune(tree, train, target_id):
    """Collapse, in place, the splits whose loss reduction is below ALPHA.

    tree: root node of the tree to prune.
    train: not used; each node's own ``train`` rows are used instead. Kept
        so existing callers keep working.
    target_id: index of the target column, passed through to ``loss``.

    Each round looks at the nodes whose children are both leaves and picks
    the one whose split reduces the squared loss the least. If that
    reduction is below ALPHA the node is turned into a leaf and the root's
    ``num_leaf`` is decremented; otherwise pruning stops.

    NOTE: only the root's ``num_leaf`` is updated; the counters stored on
    the nodes in between are left as they were.
    """
    # Looping on the root's leaf flag covers two cases the candidate search
    # cannot: a root that is already a leaf (it has no children to inspect)
    # and a root that has just been collapsed (nothing is left to prune).
    while not tree.is_leaf:
        candidates = []
        find_internal_nodes(tree, candidates)
        min_inc = INFINITY
        pnode = None
        for node in candidates:
            inc_loss = (loss(node.train, target_id)
                        - loss(node.left.train, target_id)
                        - loss(node.right.train, target_id))
            if inc_loss < min_inc:
                min_inc = inc_loss
                pnode = node
        if pnode is None or min_inc >= ALPHA:
            return
        pnode.is_leaf = True
        tree.num_leaf -= 1
def print_tree(tree, depth = 0):
    """Print ``tree`` one node per line, children indented one step deeper.

    Internal nodes show their prediction and split; leaves show their
    prediction and the number of training rows they hold.
    """
    if not tree.is_leaf:
        summary = "[internal] pred: %f, split_id: %d, split_value: %f" % (tree.value, tree.split_id, tree.split_value)
        print_info(depth, summary)
        for child in (tree.left, tree.right):
            print_tree(child, depth + 1)
        return
    print_info(depth, "[leaf] pred: %f, support: %d" % (tree.value, len(tree.train)))
class DTree(object):
    """A node of a binary regression tree; building the root builds the tree.

    Relies on the module-level ``MIN_SUPPORT``, ``MAX_DEPTH`` and ``loss``.

    Attributes set on every node:
        train: the training rows that reached this node.
        value: mean of the target column over ``train`` (the prediction).
        is_leaf: True when the node has no active split.
        num_leaf: number of leaves under this node, as counted at build time.
        split_id: column index of the split, or -1 when there is none.
        split_value: threshold; rows with ``row[split_id] <= split_value``
            go left, the others go right.
        left, right: child nodes, or None when the node was built as a leaf.
    """

    def __init__(self, train, target_id, depth = 0):
        """Build the subtree for ``train``.

        train: non-empty sequence of rows (indexable sequences of numbers).
        target_id: index of the target column within each row.
        depth: depth of this node; the root is 0.

        Raises ValueError when ``train`` is empty, since no mean prediction
        can be computed for it.
        """
        if len(train) == 0:
            raise ValueError("DTree needs at least one training row")
        self.train = train
        self.value = sum([row[target_id] for row in train]) / float(len(train))
        # Start as a leaf so that every node exposes the same attributes,
        # whichever of the early returns below is taken.
        self.is_leaf = True
        self.num_leaf = 1
        self.split_id = -1
        self.split_value = -1
        self.left = None
        self.right = None
        if len(train) < MIN_SUPPORT or depth > MAX_DEPTH:
            return
        self.split_id, self.split_value = self.split_tree(train, target_id)
        if self.split_id == -1:
            # No candidate split lowers the loss.
            return
        self.is_leaf = False
        self.left = DTree([row for row in train if row[self.split_id] <= self.split_value], target_id, depth + 1)
        self.right = DTree([row for row in train if row[self.split_id] > self.split_value], target_id, depth + 1)
        self.num_leaf = self.left.num_leaf + self.right.num_leaf

    def split_tree(self, train, target_id):
        """Return ``(split_id, split_value)`` of the split with the lowest loss.

        Every non-target column is tried, with a threshold between each pair
        of neighbouring distinct values. A split is kept only if the summed
        loss of its two sides is strictly below the best seen so far, which
        starts at the loss of the unsplit rows. Returns ``(-1, -1)`` when no
        split qualifies.
        """
        split_id = -1
        split_value = -1
        min_loss = loss(train, target_id)
        for attr_id in range(len(train[0])):
            if attr_id == target_id:
                continue
            attr_value = sorted(set(row[attr_id] for row in train))
            for lower, upper in zip(attr_value, attr_value[1:]):
                # 2.0 gives the true midpoint for integer columns on Python 2.
                v = (lower + upper) / 2.0
                if not (lower <= v < upper):
                    # Rounding or overflow pushed the midpoint outside
                    # [lower, upper); the lower value still separates the
                    # two groups and keeps both sides non-empty.
                    v = lower
                left = loss([row for row in train if row[attr_id] <= v], target_id)
                right = loss([row for row in train if row[attr_id] > v], target_id)
                if left + right < min_loss:
                    split_id = attr_id
                    split_value = v
                    min_loss = left + right
        return split_id, split_value

    def predict(self, instance):
        """Return the prediction for ``instance``, a row indexable by column."""
        if self.is_leaf:
            return self.value
        if instance[self.split_id] <= self.split_value:
            return self.left.predict(instance)
        else:
            return self.right.predict(instance)
if __name__ == "__main__":
    # Data file: the first command-line argument, else the default path.
    filename = "housing/housing.data"
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    with open(filename, 'r') as fin:
        # Rows are built as real lists because map() returns a one-shot,
        # non-indexable iterator on Python 3. Blank lines are skipped since
        # they would produce empty rows.
        data = [[float(field) for field in line.split()] for line in fin if line.strip()]
    # The leading TRAINING_RATIO share of the rows trains the tree; the rest
    # is held out for testing.
    split_point = int(len(data) * TRAINING_RATIO)
    train = data[0:split_point]
    test = data[split_point:]
    if not train:
        sys.exit("not enough data rows in %s to build a training set" % filename)
    target_id = TARGET_ID
    tree = DTree(train, target_id)
    prune(tree, train, target_id)
    print_tree(tree)
    pred_train = predict(tree, train)
    pred_test = predict(tree, test)
    # Passing alpha=0 drops the leaf penalty, leaving the plain squared error.
    print("train cost complexity criterion: %f" % evaluate(pred_train, train, target_id, tree.num_leaf))
    print("train mean squared loss: %f" % (evaluate(pred_train, train, target_id, tree.num_leaf, 0) / len(train)))
    print("test cost complexity criterion: %f" % evaluate(pred_test, test, target_id, tree.num_leaf))
    print("test mean squared loss: %f" % (evaluate(pred_test, test, target_id, tree.num_leaf, 0) / len(test)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment