Last active
December 8, 2018 15:21
-
-
Save hugobrilhante/2f090e75fd1a7e6730603fc1e460d50d to your computer and use it in GitHub Desktop.
ID3 implementation in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter, namedtuple
from dataclasses import dataclass, field
from typing import List, Optional
# One training example. The field meanings below are inferred from the
# sample data at the bottom of this file and should be confirmed:
#   rc - presumably credit risk; used as the target ('alto'/'moderado'/'baixo')
#   hc - presumably credit history ('ruim'/'desconhecida'/'boa')
#   di - presumably debt level ('alta'/'baixa')
#   ga - presumably collateral ('nenhuma'/'adequada')
#   re - presumably income bracket
EC = namedtuple('EC', 'rc hc di ga re')
@dataclass(order=True)
class Node:
    """A node of the induced decision tree.

    ``key`` holds an attribute name, an attribute value or a class value,
    depending on where the node sits in the tree. Children whose key is one
    of the class values go into the ``left``/``right`` slots. Every other
    child goes into ``branches``.
    """

    key: str
    # First class-valued child attached to this node, if any.
    left: Optional['Node'] = None
    # Second class-valued child attached to this node, if any.
    right: Optional['Node'] = None
    # Remaining children in insertion order, without (equal) duplicates.
    branches: List['Node'] = field(default_factory=list)

    def add(self, node, cls_values):
        """Attach ``node`` as a child of this node.

        :param node: the child ``Node`` to attach.
        :param cls_values: the possible values of the target attribute. A
            child whose ``key`` is among them is treated as class-valued.

        Class-valued children fill ``left`` first and then ``right``. Once
        both slots are occupied, further children go to ``branches``.
        Overwriting ``right`` instead would silently drop a child. A child
        equal to one already in ``branches`` is not appended again.
        """
        if node.key in cls_values:
            if self.left is None:
                self.left = node
                return
            if self.right is None:
                self.right = node
                return
            # Both slots are taken: keep the node rather than clobbering
            # ``right``; fall through to the ``branches`` list.
        if node not in self.branches:
            self.branches.append(node)
def is_same_cls(se, p):
    """
    Check whether every example in ``se`` has the same value for attribute ``p``.

    :param se: sequence of examples exposing ``p`` as an attribute.
    :param p: name of the attribute to compare.
    :return: ``True`` if all values of ``p`` are equal. Also ``True`` for an
        empty ``se`` (vacuous truth), where ``se[0]`` used to raise IndexError.
    """
    if not se:
        return True
    first = getattr(se[0], p)
    # A generator (not a list) lets ``all`` stop at the first mismatch.
    return all(getattr(e, p) == first for e in se)
def disjunction(se, p):
    """
    Return the most common value of attribute ``p`` among the examples in ``se``.

    Ties are broken in favour of the value encountered first. Values must be
    hashable.

    :param se: sequence of examples exposing ``p`` as an attribute.
    :param p: name of the attribute whose majority value is wanted.
    :return: the most frequent value of ``p``.
    :raises ValueError: if ``se`` is empty.
    """
    if not se:
        raise ValueError('disjunction() requires at least one example')
    # Count occurrences explicitly. max() over the raw values would return
    # the largest value (e.g. alphabetically last), not the most frequent one.
    counts = Counter(getattr(e, p) for e in se)
    return counts.most_common(1)[0][0]
def create_partition(se, p, v):
    """
    Select the examples whose attribute ``p`` equals ``v``.

    The relative order of ``se`` is preserved in the returned list.
    """
    def _matches(example):
        return getattr(example, p) == v

    return list(filter(_matches, se))
def get_target_values(se, target):
    """
    Return the distinct values of attribute ``target`` in first-seen order.

    Duplicates are detected with list membership (``==``), so the values do
    not need to be hashable.
    """
    every = [getattr(example, target) for example in se]
    return [value for index, value in enumerate(every) if value not in every[:index]]
def induce_tree(se, ps, target):
    """
    Recursively induce a decision tree from the examples ``se``.

    :param se: non-empty sequence of examples exposing the attributes named
        in ``ps`` and ``target``.
    :param ps: attribute names still available for splitting. The first one
        is always used at the current level; the rest are passed down.
    :param target: name of the class attribute to predict.
    :return: a ``Node``. It is a leaf keyed by the shared class when all
        examples agree. It is a leaf keyed by the majority class when ``ps``
        is empty. Otherwise it is a node keyed by the splitting attribute,
        with one branch per distinct value of that attribute.

    NOTE(review): attributes are consumed in the order given by ``ps``.
    Textbook ID3 would pick the attribute with the highest information gain
    at each level.

    NOTE(review): ``Node.add`` tells leaves from branches by checking
    ``key in cls_values``. An attribute value that happens to equal a class
    value would therefore land in a leaf slot.
    """
    cls_values = get_target_values(se, target)
    if is_same_cls(se, target):
        # Every example has the same class: emit a leaf.
        return Node(key=cls_values[0])
    if not ps:
        # No attributes left to split on: fall back to the majority class.
        return Node(key=disjunction(se, target))
    # Split on the first remaining attribute.
    p, remaining = ps[0], ps[1:]
    subtree = Node(key=p)
    # Visit each distinct value of p once, in first-seen order. Looping over
    # every example instead would rebuild the same partition and recurse
    # into it once per example that carries that value.
    for v in get_target_values(se, p):
        branch = Node(key=v)
        partition = create_partition(se, p, v)
        branch.add(induce_tree(partition, remaining, target), cls_values)
        subtree.add(branch, cls_values)
    return subtree
if __name__ == '__main__':
    # Attributes to split on, in the order induce_tree consumes them.
    ps_values = ['re', 'hc', 'di', 'ga']
    # Training examples, one row per EC in field order (rc, hc, di, ga, re).
    rows = [
        ('alto', 'ruim', 'alta', 'nenhuma', '$0 a $15k'),
        ('alto', 'desconhecida', 'alta', 'nenhuma', '$15k a $35k'),
        ('moderado', 'desconhecida', 'baixa', 'nenhuma', '$15k a $35k'),
        ('alto', 'desconhecida', 'baixa', 'nenhuma', '$0 a $15k'),
        ('baixo', 'desconhecida', 'baixa', 'nenhuma', 'acima de 35k'),
        ('baixo', 'desconhecida', 'baixa', 'adequada', 'acima de 35k'),
        ('alto', 'ruim', 'baixa', 'nenhuma', '$0 a $15k'),
        ('moderado', 'ruim', 'baixa', 'adequada', 'acima de 35k'),
        ('baixo', 'boa', 'baixa', 'nenhuma', 'acima de 35k'),
        ('baixo', 'boa', 'alta', 'adequada', 'acima de 35k'),
        ('alto', 'boa', 'alta', 'nenhuma', '$0 a $15k'),
        ('moderado', 'boa', 'alta', 'nenhuma', '$15k a $35k'),
        ('baixo', 'boa', 'alta', 'nenhuma', 'acima de 35k'),
        ('alto', 'ruim', 'alta', 'nenhuma', '$15k a $35k'),
    ]
    se_values = [EC(*row) for row in rows]
    # Predict the 'rc' attribute and dump the resulting Node tree.
    tree = induce_tree(se_values, ps_values, 'rc')
    print(tree)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment