Skip to content

Instantly share code, notes, and snippets.

@suxue
Created April 20, 2014 08:06
Show Gist options
  • Save suxue/11108254 to your computer and use it in GitHub Desktop.
Save suxue/11108254 to your computer and use it in GitHub Desktop.
k-means classifier
_ = require("underscore")
assert = require('assert')
# Transpose a list of rows: the i-th entry of the result gathers the i-th
# element of every row in `list`.
zip = (list) ->
  _.zip.apply null, list
# Total of all values in `list`; the fold starts from 0, so an empty list
# sums to 0.
sum = (list) ->
  add = (acc, x) -> acc + x
  _.reduce list, add, 0
# Arithmetic mean of `list`: the total of its values divided by its length.
mean = (list) ->
  total = sum list
  total / list.length
# Distance functions, looked up by name.
# Each takes two numeric vectors `p` and `q` (paired element-wise with _.zip,
# so they are expected to have the same length) and returns a number.
distances =
  # Square root of the summed squared coordinate differences.
  euclidean: (p, q) ->
    Math.sqrt(sum(_.map(_.zip(p, q), (v) -> (v[0] - v[1]) ** 2)))

  # Sum of the absolute coordinate differences.
  manhattan: (p, q) ->
    sum(_.map(_.zip(p, q), (v) -> Math.abs(v[0] - v[1])))

  # Angle between the two vectors divided by PI, i.e. 0 for vectors pointing
  # the same way and 1 for opposite vectors.
  cosine: (p, q) ->
    dot = sum(_.map(_.zip(p, q), (v) -> v[0] * v[1]))
    norm = (x) -> Math.sqrt(sum(_.map(x, (i) -> i ** 2)))
    # Floating-point rounding can push the ratio slightly outside [-1, 1]
    # for (anti)parallel vectors, where Math.acos yields NaN; clamp it first.
    # A zero-length vector still produces NaN (0 / 0), as before.
    cosTheta = Math.max(-1, Math.min(1, dot / (norm(p) * norm(q))))
    Math.acos(cosTheta) / Math.PI
# Centroid seeding strategies, looked up by name.
# Each is called as strategy(data, k, dist) and returns an array of k points
# taken from `data`. Slots are `undefined` once the candidates run out.
init =
  # Pick k points at distinct positions of `data`, each drawn at random from
  # the positions not used yet. Ignores the distance function.
  rand: (data, k) ->
    remaining = [0...data.length]
    # Exclusive ascending range: an inclusive [1..k] would count *down* for
    # k = 0 and return two points instead of none.
    for i in [0...k]
      pick = remaining[_.random(remaining.length - 1)]
      remaining = _.without(remaining, pick)
      data[pick]

  # Start from a random point (which is not itself returned), then repeatedly
  # move to the remaining point farthest from the current one under `dist`;
  # every point reached this way becomes a seed.
  farthest: (data, k, dist) ->
    pool = _.clone(data)
    anchor = pool[_.random(pool.length - 1)]
    for i in [0...k]
      # Remember the anchor itself before dropping it from the pool, so the
      # distances below are measured from it and not from whichever element
      # slides into its old slot.
      pool = _.without(pool, anchor)
      best = -Infinity
      farthestPoint = undefined
      for q in pool
        d = dist(anchor, q)
        # "not (best > d)": equally distant candidates resolve to the later one.
        unless best > d
          best = d
          farthestPoint = q
      anchor = farthestPoint
      anchor
# Command-line configuration: <k> <dist> <iter> <init> [separator].
# Arguments are validated up front so that a typo fails with a clear message
# instead of an "undefined is not a function" deep inside the clustering.
assert process.argv.length>=6, "provide arguments: k, dist, iter, init"
# _.has checks own properties only, so names such as "toString" are rejected.
assert _.has(distances, process.argv[3]),
  "unknown dist '#{process.argv[3]}', expected one of: #{_.keys(distances).join(', ')}"
assert _.has(init, process.argv[5]),
  "unknown init '#{process.argv[5]}', expected one of: #{_.keys(init).join(', ')}"
conf =
  k    : Number(process.argv[2])
  dist : distances[process.argv[3]]
  iter : Number(process.argv[4])
  init : init[process.argv[5]]
# Number() yields NaN for non-numeric text; NaN fails both comparisons below.
assert conf.k >= 1 and conf.k % 1 is 0, "k must be a positive integer"
assert conf.iter >= 1 and conf.iter % 1 is 0, "iter must be a positive integer"
# Run k-means on `data` and print the outcome.
#
#   data   - array of points, each an array of numbers
#   k      - number of initial centroids requested from `init`
#   dist   - function (p, q) -> number used to compare a point with a centroid
#   iterno - maximum number of assignment/update rounds
#   init   - seeding function, called as init(data, k, dist)
#
# Side effects: every point gets a `dist` property holding the distance to its
# nearest centroid from the latest assignment; the round counter, the group
# count, the sorted group sizes and the final centroids go to the console.
kmeans = (data, k, dist, iterno, init) ->
  centroids = init(data, k, dist)

  # Put every point into the bucket of its nearest centroid; returns one
  # bucket per centroid (possibly empty).
  assign = () ->
    # Named dist_matrix rather than `distances`: CoffeeScript would resolve
    # that name to the module-level distance table and overwrite it.
    dist_matrix = ((dist(p, c) for p in data) for c in centroids)
    buckets = ([] for c in centroids)
    # Transposing gives, per point, its distances to all centroids.
    _.each(zip(dist_matrix), (dist_vec, index) ->
      min_dist = _.min(dist_vec)
      min_index = _.indexOf(dist_vec, min_dist)
      buckets[min_index].push data[index]
      data[index].dist = min_dist)
    buckets

  # Defined up front so the report below still works when no round runs.
  groups = []
  # `by 1` forces an ascending loop; a bare [1..iterno] counts down (and runs
  # two rounds) when iterno is 0.
  for i in [1..iterno] by 1
    # Centroids that attracted no points are dropped.
    groups = _.reject(assign(), (g) -> g.length == 0)
    # New centroid = per-dimension mean of the points in the group.
    newcent = ((mean(dim) for dim in (zip(group))) for group in groups)
    if _.isEqual(newcent, centroids)
      break
    centroids = newcent
  console.log "#{i} GROUPS:", groups.length, _.sortBy(g.length for g in groups)
  console.log(centroids)
# Parse delimiter-separated numeric text and cluster it with the settings
# held in `conf`.
#
#   content - the whole input as one string, one point per line
#   sep     - field separator; falls back to ',' when not given
main = (content, sep) ->
  delimiter = sep || ','
  # One array of numbers per line. Blank lines are skipped: Number('') is 0,
  # so they would otherwise turn into a bogus one-dimensional point [0].
  readcsv = (c) ->
    for line in c.split('\n') when line.trim() isnt ''
      (Number(field) for field in line.split(delimiter))
  data = readcsv(content)
  kmeans(data, conf.k, conf.dist, conf.iter, conf.init)
# Collect everything piped to stdin, then hand the trimmed text to `main`
# together with the optional separator argument once the stream ends.
buf = []
process.stdin.resume()
process.stdin.setEncoding('utf8')
process.stdin.on('data', (chunk) ->
  buf.push(chunk))
process.stdin.on('end', ->
  main(buf.join('').trim(), process.argv[6]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment