Python Implementation of Iterative Dichotomiser 3
# -*- coding: utf-8 -*-
# =============================================
# Author: Fuad Al Abir
# Date: 12 Feb 2020
# Problem: Decision Tree - Iterative Dichotomiser 3
# Course: CSE 3210
# =============================================
# =============================================
# LIBRARIES
# =============================================
from math import log
import csv
# =============================================
# CLASS DEFINITION
# =============================================
class Node:
    def __init__(self, feature):
        self.feature = feature    # feature this node splits on ('' for a leaf)
        self.children = []        # list of (attribute_value, child_node) pairs
        self.label = ''           # predicted class ('' for an internal node)
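# Example shape (illustrative): an internal node that splits on 'Age' could
# hold children = [('18-35', <Node>), ('36-55', <Node>), ('>55', <Node>)],
# while a leaf has feature == '' and the predicted class in .label.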
# =======================================================================
# UTILITY FUNCTIONS
#
# tree: to generate the decision tree with training data
# predict: to predict test data with the decision tree generated by tree()
# entropy: to calculate entropy of each node
# gain: to calculate information gain
# sub_dataset: to create new sub table after choosing the node
# print_tree: to print the decision tree
# read_csv: to read the train_data and test_data
#
# =======================================================================
def tree(data, features):
    label = [row[-1] for row in data]
    if len(set(label)) == 1:                  # pure node: every row has the same label
        node = Node('')
        node.label = label[0]
        return node
    n = len(data[0]) - 1
    if n == 0:                                # no features left: fall back to majority label
        node = Node('')
        node.label = max(set(label), key=label.count)
        return node
    gains = [0]*n
    for col in range(n):
        gains[col] = gain(data, col)
    split = gains.index(max(gains))           # split on the feature with the highest gain
    node = Node(features[split])
    feat = features[:split] + features[split+1:]
    attr, dic = sub_dataset(data, split, delete=True)
    for x in range(len(attr)):
        child = tree(dic[attr[x]], feat)
        node.children.append((attr[x], child))
    return node
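# Note on the recursion (sketch): at each call, gain() is evaluated for every
# remaining column, the winning feature becomes node.feature, and tree()
# recurses on each value's sub-table with that column removed, so recursion
# depth is bounded by the number of features.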
def predict(_tree, dataset, features):
    print("\n==========================================\n\tPREDICTION ON TEST DATA:\n==========================================")
    for row in dataset:
        data = dict(zip(features, row))
        node = _tree
        while node.label == '':               # descend until a leaf is reached
            for value, child in node.children:
                if data[node.feature] == value:
                    node = child
                    break
            else:                             # attribute value never seen in training
                break
        flag = node.label if node.label != '' else 'Unknown'
        print(str(data) + '\nPrediction: ' + flag + '\n')
def entropy(s):
    values = list(set(s))
    if len(values) == 1:                      # all samples are in the same class
        return 0
    sums = 0
    for v in values:
        p = s.count(v)/(len(s)*1.0)
        sums += -1 * p * log(p, 2)
    return sums
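# Worked example (illustrative): for labels ['Yes', 'Yes', 'No', 'No'],
# p(Yes) = p(No) = 0.5, so
#   entropy = -0.5*log2(0.5) - 0.5*log2(0.5) = 0.5 + 0.5 = 1.0 bit.
# A pure list such as ['Yes', 'Yes'] returns 0 through the early exit above.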
def gain(data, column):
    feature_values, dictionary = sub_dataset(data, column, delete=False)
    total_entropy = entropy([row[-1] for row in data])
    for x in range(len(feature_values)):
        ratio = len(dictionary[feature_values[x]])/(len(data)*1.0)
        entr = entropy([row[-1] for row in dictionary[feature_values[x]]])
        total_entropy -= ratio*entr
    return total_entropy
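# The loop above computes the standard ID3 information gain:
#   Gain(S, A) = H(S) - sum_v (|S_v| / |S|) * H(S_v)
# where H is entropy and S_v is the subset of rows whose feature A equals v.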
def sub_dataset(data, col, delete):
    dictionary = {}
    column_data = [row[col] for row in data]
    attr = list(set(column_data))
    for value in attr:
        dictionary[value] = []
    for row in data:
        value = row[col]
        if delete:                            # drop the split column from a copy,
            row = row[:col] + row[col+1:]     # leaving the caller's rows intact
        dictionary[value].append(row)
    return attr, dictionary
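# Illustrative sketch: partitioning on column 0 with delete=True, e.g.
#   data = [['a', '1', 'Yes'], ['b', '2', 'No'], ['a', '3', 'No']]
# gives attr = ['a', 'b'] (set order may vary) and
#   dictionary = {'a': [['1', 'Yes'], ['3', 'No']], 'b': [['2', 'No']]}.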
def print_tree(node, level):
    if level == 0:
        print("\n=======================================================\n\tDECISION TREE BASED ON TRAIN DATA:\n=======================================================")
    if node.label != '':
        print(' '*level + node.label)
        return
    print(' '*level + node.feature)
    for value, _node in node.children:
        print(' '*(level+1) + value)
        print_tree(_node, level+2)
def read_csv(filename):
    lines = csv.reader(open(filename, 'r'))
    dataset = list(lines)
    feature = dataset.pop(0)                  # first row holds the column names
    return dataset, feature
# =============================================
# MAIN FUNCTION
# =============================================
dataset, feature = read_csv('train_data.csv')
_tree = tree(dataset, feature)
print_tree(_tree, 0)

test_data, test_feature = read_csv('test_data.csv')
predict(_tree, test_data, test_feature)
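# To run (assuming this script is saved as, say, id3.py with train_data.csv
# and test_data.csv in the same directory):
#   python id3.py
# It prints the induced decision tree, then one prediction per test row.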
test_data.csv

Age,Education,Income,Marital Status
36-55,Masters,High,Single
18-35,High School,Low,Single
36-55,Masters,Low,Single
18-35,Bachelors,High,Single
18-35,Bachelors,High,Married
>55,Bachelors,High,Single
train_data.csv

Age,Education,Income,Marital Status,Buy Computer
36-55,Masters,High,Single,Yes
18-35,High School,Low,Single,No
36-55,Masters,Low,Single,Yes
18-35,Bachelors,High,Single,No
<18,High School,Low,Single,Yes
18-35,Bachelors,High,Married,No
36-55,Bachelors,Low,Married,No
>55,Bachelors,High,Single,Yes
36-55,Masters,Low,Married,No
>55,Masters,Low,Married,Yes
36-55,Masters,High,Single,Yes
>55,Masters,High,Single,Yes
<18,High School,High,Single,No
36-55,Masters,Low,Single,Yes
36-55,High School,Low,Single,Yes
<18,High School,Low,Married,Yes
18-35,Bachelors,High,Married,No
>55,High School,High,Married,Yes
>55,Bachelors,Low,Single,Yes
36-55,High School,High,Married,No