Skip to content

Instantly share code, notes, and snippets.

@hugobrilhante
Last active December 8, 2018 15:21
Show Gist options
  • Select an option

  • Save hugobrilhante/2f090e75fd1a7e6730603fc1e460d50d to your computer and use it in GitHub Desktop.

Select an option

Save hugobrilhante/2f090e75fd1a7e6730603fc1e460d50d to your computer and use it in GitHub Desktop.
ID3-style decision tree induction in Python (attributes are split in the order given, not ranked by information gain)
from collections import Counter, namedtuple
from dataclasses import dataclass, field
from typing import List, Optional
# A single training example, one field per attribute.
# NOTE(review): the abbreviations look Portuguese, judging by the sample data
# in the __main__ block — confirm with the author:
#   rc - presumably credit risk ('alto' / 'moderado' / 'baixo'); used as the target there
#   hc - presumably credit history, di - debt, ga - collateral, re - income bracket
EC = namedtuple('EC', 'rc hc di ga re')
@dataclass(order=True)
class Node:
    """
    A node in the induced decision tree.

    ``key`` is an attribute name, an attribute value, or a class value,
    depending on where the node sits in the tree. Children are attached with
    :meth:`add`:
    - A child whose key is one of the supplied class values fills ``left``,
      then ``right``.
    - Every other child is appended to ``branches``.
    """

    key: str
    left: Optional['Node'] = None  # first attached class-valued child
    right: Optional['Node'] = None  # second attached class-valued child
    branches: List['Node'] = field(default_factory=list)  # other children, no duplicates

    def add(self, node, cls_values):
        """
        Attach *node* as a child of this node.

        Args:
            node: the child Node to attach.
            cls_values: collection of class values. A child whose ``key`` is in
                it goes into ``left``, then ``right``; anything else goes into
                ``branches``.

        A child equal to one already stored in ``left``/``right`` or
        ``branches`` is ignored. If ``left`` and ``right`` are both occupied,
        further class-valued children are appended to ``branches``. They are
        not written over ``right``, so no child is lost.
        """
        if node.key in cls_values:
            if self.left is None:
                self.left = node
                return
            if node == self.left:
                return
            if self.right is None:
                self.right = node
                return
            if node == self.right:
                return
            # Both slots are taken: fall through and keep the node in
            # ``branches`` instead of silently replacing ``right``.
        if node not in self.branches:
            self.branches.append(node)
def is_same_cls(se, p):
    """
    Return True if every example in *se* has the same value for attribute *p*.

    Args:
        se: sequence of examples exposing *p* as an attribute.
        p: name of the attribute to compare.

    An empty *se* is vacuously homogeneous and yields True; it does not raise
    IndexError.
    """
    if not se:
        return True
    first = getattr(se[0], p)
    # Generator rather than a list so all() stops at the first mismatch.
    return all(getattr(e, p) == first for e in se)
def disjunction(se, p):
    """
    Return the most common value of attribute *p* among the examples in *se*.

    Ties go to the value that appears first in *se*.

    Args:
        se: sequence of examples exposing *p* as an attribute (values must be
            hashable).
        p: name of the attribute whose majority value is wanted.

    Raises:
        ValueError: if *se* is empty.
    """
    if not se:
        raise ValueError('disjunction() requires at least one example')
    # max(se, key=getattr) would return the lexicographically largest value,
    # not the most frequent one; count occurrences instead.
    ((value, _count),) = Counter(getattr(e, p) for e in se).most_common(1)
    return value
def create_partition(se, p, v):
    """
    Select the examples whose attribute *p* equals *v*.

    Returns a new list in the original order of *se*; *se* itself is left
    untouched.
    """
    def has_value(example):
        return getattr(example, p) == v

    return list(filter(has_value, se))
def get_target_values(se, target):
    """
    Return the distinct values of attribute *target* across *se*.

    Values appear in the order they are first encountered. Uniqueness is
    checked with ``==`` against a list, so values need not be hashable.
    """
    distinct = []
    for value in (getattr(example, target) for example in se):
        if value in distinct:
            continue
        distinct.append(value)
    return distinct
def induce_tree(se, ps, target):
    """
    Recursively build a decision tree from the examples in *se*.

    The first remaining attribute in *ps* is always the one split on. No
    information-gain ranking is performed, so the caller controls the order.

    Args:
        se: non-empty sequence of examples exposing attributes by name.
        ps: sequence of attribute names still available for splitting.
        target: name of the class attribute to predict.

    Returns:
        A Node.
        - If all examples share a class, a leaf whose key is that class.
        - If no attributes remain, a leaf whose key is the majority class.
        - Otherwise, a node keyed by the split attribute. It gets one child
          Node per distinct attribute value (keyed by that value), and each
          child holds the recursively induced subtree for its partition.

    Raises:
        ValueError: if *se* is empty.
    """
    if not se:
        raise ValueError('induce_tree() requires at least one example')
    cls_values = get_target_values(se, target)
    if is_same_cls(se, target):
        # Every example agrees on the class: emit a leaf.
        return Node(key=cls_values[0])
    if not ps:
        # Classes are mixed but no attributes are left: use the majority class.
        return Node(key=disjunction(se, target))
    p, remaining = ps[0], ps[1:]
    subtree = Node(key=p)
    # One branch per *distinct* value of p, in first-seen order. Looping once
    # per example instead would re-induce the same subtree for every example
    # sharing a value and rely on Node.add's dedup to throw the copies away.
    for v in get_target_values(se, p):
        branch = Node(key=v)
        partition = create_partition(se, p, v)
        branch.add(induce_tree(partition, remaining, target), cls_values)
        subtree.add(branch, cls_values)
    return subtree
if __name__ == '__main__':
    # Training rows in EC field order: (rc, hc, di, ga, re).
    rows = [
        ('alto', 'ruim', 'alta', 'nenhuma', '$0 a $15k'),
        ('alto', 'desconhecida', 'alta', 'nenhuma', '$15k a $35k'),
        ('moderado', 'desconhecida', 'baixa', 'nenhuma', '$15k a $35k'),
        ('alto', 'desconhecida', 'baixa', 'nenhuma', '$0 a $15k'),
        ('baixo', 'desconhecida', 'baixa', 'nenhuma', 'acima de 35k'),
        ('baixo', 'desconhecida', 'baixa', 'adequada', 'acima de 35k'),
        ('alto', 'ruim', 'baixa', 'nenhuma', '$0 a $15k'),
        ('moderado', 'ruim', 'baixa', 'adequada', 'acima de 35k'),
        ('baixo', 'boa', 'baixa', 'nenhuma', 'acima de 35k'),
        ('baixo', 'boa', 'alta', 'adequada', 'acima de 35k'),
        ('alto', 'boa', 'alta', 'nenhuma', '$0 a $15k'),
        ('moderado', 'boa', 'alta', 'nenhuma', '$15k a $35k'),
        ('baixo', 'boa', 'alta', 'nenhuma', 'acima de 35k'),
        ('alto', 'ruim', 'alta', 'nenhuma', '$15k a $35k'),
    ]
    se_values = [EC(*row) for row in rows]
    # Attributes in the order induce_tree should split on them; 'rc' is the target.
    ps_values = ['re', 'hc', 'di', 'ga']
    tree = induce_tree(se_values, ps_values, 'rc')
    print(tree)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment