Last active
October 17, 2019 18:09
-
-
Save thomasnield/3f751ae1e78ed12e29b04cc555dbb93c to your computer and use it in GitHub Desktop.
purity_vs_impurity_split.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Optional
class Point:
    """A labelled sample: an integer position ``x`` and a boolean label ``y``."""

    def __init__(self, x: int, y: bool):
        # Store both constructor arguments unchanged as public attributes.
        self.x, self.y = x, y
# Sample data set: ten points at x = 1..10, each paired with a boolean label.
points: List[Point] = [
    Point(x, label)
    for x, label in (
        (1, False),
        (2, True),
        (3, True),
        (4, True),
        (5, False),
        (6, True),
        (7, False),
        (8, False),
        (9, True),
        (10, False),
    )
]
def biased_purity_for_split(
    split_boundary: int, data: Optional[List["Point"]] = None
) -> float:
    """Score a split by how well it puts True labels left and False labels right.

    Points with ``x < split_boundary`` form the left group and the rest form
    the right group. The left group's purity is its share of True labels, the
    right group's purity is its share of False labels, and the result is the
    average of the two weighted by group size.

    Args:
        split_boundary: Threshold compared against each point's ``x``.
        data: Points to evaluate; each needs ``x`` and ``y`` attributes.
            Defaults to the module-level ``points`` list.

    Returns:
        The size-weighted purity in ``[0.0, 1.0]``, or 0.0 when there are no
        points at all.
    """
    if data is None:
        data = points

    left_points = [p for p in data if p.x < split_boundary]
    right_points = [p for p in data if p.x >= split_boundary]

    total = len(left_points) + len(right_points)
    if total == 0:
        # Nothing to score; avoids dividing by zero in the weights below.
        return 0.0

    # An empty side has weight zero, so its purity value is irrelevant;
    # 0.0 simply avoids a division by zero without skewing the other side.
    left_purity = 0.0
    if left_points:
        left_purity = sum(1 for p in left_points if p.y) / len(left_points)

    right_purity = 0.0
    if right_points:
        right_purity = sum(1 for p in right_points if not p.y) / len(right_points)

    left_weight = len(left_points) / total
    right_weight = len(right_points) / total
    return left_purity * left_weight + right_purity * right_weight
def find_best_biased_split() -> (int, float):
    """Scan every point's ``x`` as a boundary and keep the highest biased purity.

    Returns:
        A ``(boundary, purity)`` tuple. The search starts from the first
        point's ``x`` with a purity of 0.0 and only replaces the current best
        on a strictly greater score, so ties keep the earliest boundary.
    """
    top_boundary: int = points[0].x
    top_score: float = 0.0
    candidates = ((pt.x, biased_purity_for_split(pt.x)) for pt in points)
    for boundary, score in candidates:
        if score > top_score:
            top_boundary, top_score = boundary, score
    return top_boundary, top_score
def _gini_impurity(group) -> float:
    """Return the Gini impurity ``1 - p_true**2 - p_false**2`` of *group*.

    An empty group yields 0.0; its weight in a split is zero, so the value
    never influences the weighted result.
    """
    size = len(group)
    if size == 0:
        return 0.0
    true_count = sum(1 for p in group if p.y)
    true_share = true_count / size
    false_share = (size - true_count) / size
    return 1.0 - true_share ** 2 - false_share ** 2


def impurity_for_split(
    split_boundary: int, data: Optional[List["Point"]] = None
) -> float:
    """Return the size-weighted Gini impurity of splitting at *split_boundary*.

    Points with ``x < split_boundary`` form the left group and the rest form
    the right group. Each group's class shares are computed against that
    group's own size, and the two impurities are averaged by group size.

    Args:
        split_boundary: Threshold compared against each point's ``x``.
        data: Points to evaluate; each needs ``x`` and ``y`` attributes.
            Defaults to the module-level ``points`` list.

    Returns:
        The weighted impurity in ``[0.0, 0.5]`` (0.0 means both sides are
        pure), or 0.0 when there are no points at all.
    """
    if data is None:
        data = points

    left_points = [p for p in data if p.x < split_boundary]
    right_points = [p for p in data if p.x >= split_boundary]

    total = len(left_points) + len(right_points)
    if total == 0:
        # Nothing to score; avoids dividing by zero in the weights below.
        return 0.0

    left_weight = len(left_points) / total
    right_weight = len(right_points) / total
    return (
        _gini_impurity(left_points) * left_weight
        + _gini_impurity(right_points) * right_weight
    )
def find_best_impurity_split() -> (int, float):
    """Scan every point's ``x`` as a boundary and keep the lowest impurity.

    Each candidate is scored with ``impurity_for_split``; because impurity
    measures how mixed the two sides are, the best split is the one with the
    smallest score.

    Returns:
        A ``(boundary, impurity)`` tuple. The current best is replaced only on
        a strictly smaller score, so ties keep the earliest boundary in
        ``points`` order.

    Raises:
        IndexError: If ``points`` is empty.
    """
    best_split: int = points[0].x
    # Start above any reachable impurity so the first candidate always wins.
    best_impurity: float = float("inf")
    for p in points:
        new_impurity = impurity_for_split(p.x)
        if new_impurity < best_impurity:
            best_impurity = new_impurity
            best_split = p.x
    return best_split, best_impurity
# Run both split searches in order and print each returned tuple on its own line.
for _finder in (find_best_biased_split, find_best_impurity_split):
    print(_finder())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment