Skip to content

Instantly share code, notes, and snippets.

@ajtulloch
Created November 3, 2013 21:22
Show Gist options
  • Save ajtulloch/7294986 to your computer and use it in GitHub Desktop.
Save ajtulloch/7294986 to your computer and use it in GitHub Desktop.
import operator
from collections import namedtuple
# One training example: `features` maps a feature key to its value and
# `label` is the target; `loss` does arithmetic on it, so it must be numeric
# (bools count as 0/1).
Example = namedtuple("Example", "features label")
def loss(pairs):
    """
    L^2 loss: sum of squared deviations of the labels from their mean.

    :param pairs: sized sequence of ``(feature_value, label)`` pairs; only
        the label (second element) is used. Labels must support arithmetic
        (bools count as 0/1).
    :returns: ``sum((label - mean) ** 2)`` over the labels, or ``0.0`` for
        an empty sequence.
    """
    if not pairs:
        return 0.0
    labels = [label for _, label in pairs]
    # operator.truediv always performs true division, even on Python 2 where
    # ``/`` floors when the label sum is an int/bool: mean([F, F, T, T])
    # would otherwise be 0 instead of 0.5, inflating the loss.
    average_label = operator.truediv(sum(labels), len(labels))
    return sum((label - average_label) ** 2 for label in labels)
def get_best_split(examples, features):
    """
    Find the single-feature threshold split that most reduces the L^2 loss.

    For each feature, the examples are sorted by that feature's value and
    every threshold ``value`` between two distinct consecutive values is
    tried: examples whose feature value is smaller than ``value`` go left,
    the rest go right.

    :param examples: sequence of ``Example``; each ``e.features`` must contain
        every key in ``features`` (a missing key raises KeyError).
    :param features: iterable of feature keys to consider.
    :returns: ``(best_feature, best_value, best_loss_reduction)`` where
        ``best_loss_reduction`` is ``loss(left) + loss(right) - loss(all)``
        for the chosen split, i.e. negative when the split helps. If no split
        lowers the loss, ``(0, 0.0, 0.0)`` is returned unchanged.
    """
    best_feature, best_value, best_loss_reduction = 0, 0.0, 0.0
    for feature in features:
        pairs = sorted((e.features[feature], e.label) for e in examples)
        # The loss of the unsplit node is the same for every split point.
        total_loss = loss(pairs)
        # Index 0 would leave the left side empty (a no-op split), so start at 1.
        for index in range(1, len(pairs)):
            previous_value, value = pairs[index - 1][0], pairs[index][0]
            # A threshold can only separate *distinct* values; splitting inside
            # a run of equal values would put identical feature values on both
            # sides, which no "< value" rule reproduces.
            if previous_value == value:
                continue
            left, right = pairs[:index], pairs[index:]
            current_loss_reduction = loss(left) + loss(right) - total_loss
            # Strict '<' keeps the first split found among equally good ones.
            if current_loss_reduction < best_loss_reduction:
                best_feature = feature
                best_value = value
                best_loss_reduction = current_loss_reduction
    return (best_feature, best_value, best_loss_reduction)
if __name__ == "__main__":
    # Dummy dataset where label == (features[5] >= 2.0); feature 6 is constant
    # and therefore carries no information.
    print(get_best_split(
        examples=[
            Example(features={5: 1.0, 6: 2.0}, label=False),
            Example(features={5: 1.5, 6: 2.0}, label=False),
            Example(features={5: 2.0, 6: 2.0}, label=True),
            Example(features={5: 3.0, 6: 2.0}, label=True),
        ],
        features=[5, 6]))
    # Expected output: (5, 2.0, -1.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment