Skip to content

Instantly share code, notes, and snippets.

@thomasnield
Last active October 17, 2019 18:09
Show Gist options
  • Select an option

  • Save thomasnield/3f751ae1e78ed12e29b04cc555dbb93c to your computer and use it in GitHub Desktop.

Select an option

Save thomasnield/3f751ae1e78ed12e29b04cc555dbb93c to your computer and use it in GitHub Desktop.
purity_vs_impurity_split.py
from typing import List
class Point:
    """A labelled sample: an integer feature ``x`` and a boolean label ``y``."""

    def __init__(self, x: int, y: bool):
        # Store the feature value and its label as plain attributes.
        self.x, self.y = x, y
# Sample data set: x runs 1..10 and each x is paired, in order, with the
# label at the same position in this list.
points: List[Point] = [
    Point(x, label)
    for x, label in enumerate(
        [False, True, True, True, False, True, False, False, True, False],
        start=1,
    )
]
def biased_purity_for_split(split_boundary: int, data=None) -> float:
    """Score a split assuming True-labelled points belong on the left.

    Points with ``x < split_boundary`` go to the left side and all others
    (including ``x == split_boundary``) go to the right. The left side is
    scored by its fraction of points whose ``y`` is truthy. The right side is
    scored by its fraction of points whose ``y`` is falsy. The two fractions
    are combined, weighted by each side's share of all points.

    Args:
        split_boundary: x value at which to split.
        data: Points to score (objects with ``x`` and ``y`` attributes).
            Defaults to the module-level ``points`` list.

    Returns:
        The weighted purity, between 0.0 and 1.0. An empty side contributes 0.0.

    Raises:
        ZeroDivisionError: If there are no points at all.
    """
    if data is None:
        data = points
    left = [p for p in data if p.x < split_boundary]
    right = [p for p in data if p.x >= split_boundary]
    total = len(left) + len(right)

    # Guard empty sides explicitly instead of adding an epsilon to the
    # denominator, which would skew every non-empty score slightly below
    # its true value.
    left_purity = sum(1 for p in left if p.y) / len(left) if left else 0.0
    right_purity = sum(1 for p in right if not p.y) / len(right) if right else 0.0

    # Dividing by total (not guarded) keeps the original failure on empty data.
    left_weight = len(left) / total
    right_weight = len(right) / total
    return left_purity * left_weight + right_purity * right_weight
def find_best_biased_split() -> (int, float):
    """Try each point's x as a split boundary and keep the highest biased purity.

    Candidates are scanned in the order of ``points``. A later candidate
    replaces the current winner only if it scores strictly higher, so ties go
    to the earliest one. If nothing scores above 0.0, the first point's x is
    returned with 0.0.

    Returns:
        A ``(boundary, purity)`` tuple.
    """
    # Seeding from points[0] means an empty points list fails here, before any scoring.
    winner = (points[0].x, 0.0)
    for candidate in (pt.x for pt in points):
        score = biased_purity_for_split(candidate)
        if score > winner[1]:
            winner = (candidate, score)
    return winner
def _gini_impurity(side) -> float:
    """Return the Gini impurity ``1 - p_true**2 - p_false**2`` of one side.

    An empty side returns 0.0. Its weight in the split is 0 anyway, so it
    contributes nothing to the weighted total.
    """
    if not side:
        return 0.0
    count = len(side)
    p_true = sum(1 for p in side if p.y) / count
    p_false = sum(1 for p in side if not p.y) / count
    return 1.0 - p_true ** 2 - p_false ** 2


def impurity_for_split(split_boundary: int, data=None) -> float:
    """Return the weighted Gini impurity of splitting at ``split_boundary``.

    Points with ``x < split_boundary`` go to the left side and all others
    (including ``x == split_boundary``) go to the right. Each side's Gini
    impurity is weighted by its share of all points. Lower is better: 0.0
    means each side holds a single label.

    Args:
        split_boundary: x value at which to split.
        data: Points to score (objects with ``x`` and ``y`` attributes).
            Defaults to the module-level ``points`` list.

    Returns:
        The weighted impurity, between 0.0 and 0.5 for two labels.

    Raises:
        ZeroDivisionError: If there are no points at all.
    """
    if data is None:
        data = points
    left = [p for p in data if p.x < split_boundary]
    right = [p for p in data if p.x >= split_boundary]
    total = len(left) + len(right)
    left_weight = len(left) / total
    right_weight = len(right) / total
    # Each side is normalised by its OWN size. Previously the right side's
    # "false" fraction was divided by the left side's count, which could
    # drive the impurity negative.
    return _gini_impurity(left) * left_weight + _gini_impurity(right) * right_weight
def find_best_impurity_split() -> (int, float):
    """Try each point's x as a split boundary and keep the lowest Gini impurity.

    Candidates are scanned in the order of ``points``. A later candidate
    replaces the current best only if its impurity is strictly lower, so ties
    go to the earliest one.

    Returns:
        A ``(boundary, impurity)`` tuple.

    Raises:
        IndexError: If ``points`` is empty.
    """
    best_split: int = points[0].x
    # Impurity is minimised, so start above any achievable value.
    best_impurity: float = float("inf")
    for p in points:
        # Score with the impurity measure; this previously called the biased
        # purity function and maximised it, duplicating the other search.
        new_impurity = impurity_for_split(p.x)
        if new_impurity < best_impurity:
            best_impurity = new_impurity
            best_split = p.x
    return best_split, best_impurity
# Run both split searches over the module-level ``points`` and print the
# (boundary, score) tuple each one returns.
print(find_best_biased_split())
print(find_best_impurity_split())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment