Created November 3, 2013 17:32
-
-
Save ajtulloch/7292671 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple

# One training example: ``features`` is looked up by feature key (the demo
# uses a dict) and ``label`` is the target the L2 loss is computed on.
Example = namedtuple('Example', 'features label')
def loss(pairs):
    """Return the L2 loss of a node: the sum of squared deviations of each
    label from the mean label.

    :param pairs: iterable of ``(feature_value, label)`` tuples; only the
        label (second element) is used. Labels must be numeric (bools count
        as 0/1).
    :returns: the sum of squared deviations as a float; ``0.0`` when
        *pairs* is empty.
    """
    # Materialise the labels first so that empty iterators (e.g. a Python 3
    # ``zip`` object, which is always truthy) are detected correctly.
    labels = [label for _, label in pairs]
    if not labels:
        return 0.0
    # float() forces true division: on Python 2, int/bool labels would
    # otherwise floor the mean (e.g. sum([F, F, T, T]) / 4 == 0).
    average_label = float(sum(labels)) / len(labels)
    return sum((label - average_label) ** 2 for label in labels)
def get_best_split(examples, features):
    """Find the single-feature threshold split that most lowers the L2 loss.

    For each feature, the examples are ordered by that feature's value, and
    every boundary between two *distinct* consecutive values is tried. Pairs
    whose value is strictly less than the candidate ``value`` form the left
    child; the rest form the right child.

    :param examples: sequence of ``Example`` records; ``e.features[feature]``
        must exist for every requested feature and ``e.label`` must be numeric.
    :param features: iterable of feature keys to evaluate.
    :returns: ``(best_feature, best_value, best_loss_reduction)``.
        ``best_loss_reduction`` is ``loss(left) + loss(right) - loss(all)``.
        It is negative for a useful split, and the most negative one wins
        (the first one found wins ties). If no split lowers the loss,
        ``(0, 0.0, 0.0)`` is returned.
    """
    best_feature, best_value, best_loss_reduction = 0, 0.0, 0.0
    for feature in features:
        # ``sorted`` returns a new list; the original discarded it, so pairs
        # were never sorted. Building a list also keeps ``pairs`` sliceable
        # on Python 3, where ``zip``/``map`` are lazy iterators. Sort on the
        # feature value only.
        pairs = sorted(
            ((e.features[feature], e.label) for e in examples),
            key=lambda pair: pair[0])
        # The parent's loss does not depend on the split point; compute once.
        total_loss = loss(pairs)
        # Index 0 gives an empty left child, i.e. a change of exactly 0.
        # That can never beat the current best (which starts at 0.0), so
        # start at 1.
        for index in range(1, len(pairs)):
            value = pairs[index][0]
            # A cut between equal values cannot be expressed as the
            # threshold ``x < value``; only consider value boundaries.
            if value == pairs[index - 1][0]:
                continue
            left, right = pairs[:index], pairs[index:]
            current_loss_reduction = loss(left) + loss(right) - total_loss
            if current_loss_reduction < best_loss_reduction:
                best_feature = feature
                best_value = value
                best_loss_reduction = current_loss_reduction
    return (best_feature, best_value, best_loss_reduction)
""" | |
Dummy dataset where label == features[5] > 2.0 | |
""" | |
print get_best_split( | |
examples=[ | |
Example(features={5: 1.0, 6: 2.0}, label=False), | |
Example(features={5: 1.5, 6: 2.0}, label=False), | |
Example(features={5: 2.0, 6: 2.0}, label=True), | |
Example(features={5: 3.0, 6: 2.0}, label=True), | |
], | |
features=[5,6]) | |
""" | |
>>> (5, 2.0, -2.0) | |
""" |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.