Created
December 27, 2012 22:57
-
-
Save colinpollock/4392898 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import itertools | |
import math | |
import pprint | |
import random | |
import sys | |
# Positive infinity: compares greater than every finite number.
INF = float("inf")
def closest(num, means):
    """Return the element of `means` nearest to `num`.

    Nearness is the absolute difference. Candidates are ranked by the pair
    (distance, value), so when two of them are equally near the smaller one
    is returned. `means` must be non-empty.
    """
    assert means

    def rank(candidate):
        return (abs(candidate - num), candidate)

    return min(means, key=rank)
def mean(nums):
    """Return the arithmetic mean of `nums`.

    `nums` must be a non-empty sized collection of numbers; an empty one
    raises ZeroDivisionError. The module enables true division, so integer
    inputs can produce a fractional result.
    """
    total = sum(nums)
    count = len(nums)
    return total / count
def variance(nums):
    """Return the population variance of `nums`.

    This is the average squared deviation from the mean, dividing by
    `len(nums)` rather than `len(nums) - 1`.
    """
    center = mean(nums)
    squared_deviations = [(center - value) ** 2 for value in nums]
    return sum(squared_deviations) / len(nums)
def std(nums):
    """Return the standard deviation of `nums` (square root of `variance`)."""
    spread = variance(nums)
    return math.sqrt(spread)
def group(means, nums):
    """Bucket every number in `nums` under the mean it is nearest to.

    Returns a dict keyed by each value in `means`; a mean that attracts no
    numbers maps to an empty list.
    """
    buckets = dict((center, []) for center in means)
    for value in nums:
        nearest = closest(value, means)
        buckets[nearest].append(value)
    return buckets
def _cluster(nums, k):
    """Run k-means once on the 1-D values in `nums` from a random start.

    Args:
        nums: a list of numbers.
        k: the number of starting means to draw from `nums`.

    Returns:
        A dict mapping each final mean to the list of numbers nearest to it.

    Raises:
        ValueError: if `k` is negative or larger than `len(nums)` (raised by
            `random.sample`).
    """
    # Seed from the distinct values when there are enough of them. Drawing
    # duplicates would give equal starting means, which collapse into one
    # dict key in `group` and silently yield fewer than `k` clusters. With
    # too few distinct values, fall back to sampling `nums` directly.
    distinct = sorted(set(nums))
    candidates = distinct if len(distinct) >= k else nums
    means = sorted(random.sample(candidates, k))

    # Reassign numbers and re-average until the means stop changing.
    while True:
        groups = group(means, nums)
        # `.values()` (not the Python 2-only `.itervalues()`) keeps this
        # runnable on both Python 2 and Python 3.
        new_means = sorted(mean(members) for members in groups.values())
        if new_means == means:
            # `groups` was built from exactly these means, so it is already
            # the final assignment; no need to regroup.
            return groups
        means = new_means
def cluster(nums, k, iters=1):
    """Cluster `nums` into `k` buckets.

    Args:
        nums: a list of numbers
        k: the number of clusters
        iters: the number of times to repeat clustering, each from a fresh
            random start. Must be positive.

    Returns:
        The clustering (a dict from each mean to its list of numbers) whose
        summed per-cluster variance is the smallest among the `iters`
        attempts.
    """
    assert iters > 0

    def total_variance(clustering):
        # Score one clustering: sum of the variance of each bucket.
        return sum(variance(members) for members in clustering.values())

    # `range` rather than the Python 2-only `xrange`, so this also runs on
    # Python 3.
    attempts = [_cluster(nums, k) for _ in range(iters)]
    return min(attempts, key=total_variance)
def main(args):
    """Cluster the numbers given on the command line and print each cluster.

    Args:
        args: command-line arguments without the program name:
            args[0]: comma-separated integers, e.g. "1,2,10,11"
            args[1]: the number of clusters
            args[2]: optional number of clustering attempts (default 1)
    """
    if len(args) < 2:
        sys.exit('usage: NUMS K [ITERS]  (NUMS is comma-separated integers)')

    values = [int(field) for field in args[0].split(',')]
    num_groups = int(args[1])
    num_iters = int(args[2]) if len(args) > 2 else 1

    groups = cluster(values, num_groups, num_iters)
    # Sorted by mean so the output order does not depend on dict ordering.
    for center, members in sorted(groups.items()):
        # A single pre-formatted string in parentheses prints identically
        # under the Python 2 print statement and the Python 3 print function.
        print('cluster=%s, mean=%.2f, std=%.2f'
              % (members, center, std(members)))
# Script entry point: pass everything after the program name to main().
if __name__ == "__main__":
    main(sys.argv[1:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment