Created
July 14, 2014 05:59
-
-
Save hardbyte/ded34566f6fb704264b4 to your computer and use it in GitHub Desktop.
K-means with D3js
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<meta charset="utf-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1"> | |
<!-- Latest compiled and minified CSS --> | |
<link rel="stylesheet" href="//maxcdn.bootstrapcdn.com/bootstrap/3.2.0/css/bootstrap.min.css"> | |
<!-- Optional theme --> | |
<link rel="stylesheet" href="//maxcdn.bootstrapcdn.com/bootstrap/3.2.0/css/bootstrap-theme.min.css"> | |
<style> | |
/* Voronoi cell outlines drawn by step() around each mean. */
path {
  stroke: #a51314;
  fill: none;
}
/* Generated sample points (class "data", drawn by plotData). */
circle.data {
  fill: steelblue;
  pointer-events: none;
}
/* Current cluster means (class "means", drawn by plotMeans). */
circle.means {
  fill: red;
  opacity: 0.3;
}
/* Centre axes (class "ax").
   The length must be unquoted: a quoted "2px" is an invalid value and is ignored. */
line.ax {
  stroke: grey;
  stroke-width: 2px;
}
</style> | |
</head> | |
<body> | |
<div class="container-fluid"> | |
<div class="row"> | |
<div id="vis"></div> | |
</div> | |
</div> | |
<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.11.1/jquery.min.js"></script> | |
<script src="//maxcdn.bootstrapcdn.com/bootstrap/3.2.0/js/bootstrap.min.js"></script> | |
<script src="http://d3js.org/d3.v3.min.js"></script> | |
<script src="http://numericjs.com/lib/numeric-1.2.6.min.js"></script> | |
<script> | |
// Plot geometry: the drawable area is the SVG size minus these margins.
var margin = {top: 20, right: 20, bottom: 30, left: 50};
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Pixel scales for the data coordinates. No domain is set here, so d3's
// default domain applies. The Y range is flipped so larger values are drawn higher.
var X = d3.scale.linear().range([0, width]);
var Y = d3.scale.linear().range([height, 0]);

// Voronoi generator used to outline each mean's cell, clipped to the plot area.
var voronoi = d3.geom.voronoi().clipExtent([[0, 0], [width, height]]);

// `svg` is the inner group, translated by the margins, inside a full-size <svg>.
var svg = d3.select("#vis")
  .append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

// Clustering parameters.
var k = 4;                       // clusters generated by createData and fitted by kmeans/step
var numSamplesPerFrame = 10;     // samples revealed per animation frame
var numSamplesPerCluster = 200;  // samples generated for each cluster

// Not referenced elsewhere in this script.
var data;

// Sample coordinates and the index of the mean each sample is assigned to.
var xdata = [];
var ydata = [];
var cdata = [];

// Current mean coordinates, one entry per cluster.
var x_means = [];
var y_means = [];
// Draw the normal axes: a horizontal ("X") and a vertical ("Y") line through
// the middle of the data area, keyed by their label.
var components = svg.selectAll("line.ax")
  .data([
    [[0.0, 0.5], [1, 0.5], "X"],
    [[0.5, 0.0], [0.5, 1], "Y"]
  ], function (segment) {
    return segment[2];
  });

components.exit().remove();

// Each segment is [start, end, label], with start/end given as [x, y] in data units.
components.enter()
  .append('line')
  .attr('class', 'ax')
  .attr('x1', function (segment) { var start = segment[0]; return X(start[0]); })
  .attr('y1', function (segment) { var start = segment[0]; return Y(start[1]); })
  .attr('x2', function (segment) { var end = segment[1]; return X(end[0]); })
  .attr('y2', function (segment) { var end = segment[1]; return Y(end[1]); });
/**
 * Approximately normal random number built from ten uniform draws.
 *
 * Sums ten values uniform on [-1, 1], scales the sum by `std` and shifts it by `mean`.
 * Each uniform term has variance 1/3, so the result's actual standard deviation is
 * sqrt(10/3) * std (about 1.83 * std), not `std` itself.
 *
 * @param {number} mean - centre of the distribution.
 * @param {number} std - scale factor applied to the summed draws.
 * @returns {number} the sampled value.
 */
function rnd(mean, std){
  var total = 0;
  var draws = 10;
  while (draws--) {
    total += 2 * Math.random() - 1;
  }
  return mean + std * total;
}
/**
 * Initialise k-means: place each of the k means at a uniformly random point
 * in the unit square, then draw them with plotMeans().
 */
function kmeans(){
  // Step 1: choose k random starting positions. The counter is declared with
  // `var` so it stays local instead of creating (or, in strict mode, failing
  // on) an implicit global `i`.
  for (var i = 0; i < k; i++) {
    x_means[i] = Math.random();
    y_means[i] = Math.random();
  }
  plotMeans();
}
// Number of k-means iterations step() has run so far.
var numSteps = 0;

/**
 * Run one Lloyd iteration of k-means and update the display.
 *
 * 1. Draws the Voronoi cells of the current means.
 * 2. Assigns every sample to its nearest mean, storing the index in cdata.
 * 3. Moves each mean to the centroid of its samples. A mean with no samples
 *    is re-seeded at a random point in the unit square.
 *
 * Used as a d3.timer callback: returning true stops the timer.
 *
 * @returns {boolean} true when at least `minSteps` iterations have run and no
 *   mean moved in this one, or unconditionally once `maxSteps` is reached.
 *   The cap is needed because a re-seeded empty cluster moves on every
 *   iteration and would otherwise keep the animation running forever.
 */
function step() {
  // Always animate at least this many iterations before testing convergence.
  var minSteps = 100;
  // Hard upper bound so the timer is guaranteed to stop.
  var maxSteps = 1000;
  var path = svg.selectAll("path");

  // SVG path string for a closed polygon given as a list of [x, y] points.
  function polygon(d) {
    return "M" + d.join("L") + "Z";
  }

  // Redraw one outline per Voronoi cell, keyed by the cell's path string.
  function redraw() {
    var sites = [];
    for (var m = 0; m < k; m++) {
      sites.push([X(x_means[m]), Y(y_means[m])]);
    }
    var v = path.data(voronoi(sites), polygon);
    v.exit().remove();
    // In D3 v3, enter().append() also merges the new nodes into `v`.
    v.enter().append("path");
    v.attr("d", polygon).order();
  }

  redraw();

  // Assignment step: label each sample with the index of its nearest mean,
  // i.e. partition the samples by the Voronoi cells drawn above.
  for (var i = 0; i < xdata.length; i++) {
    var nearestDistance = Infinity;
    for (var j = 0; j < k; j++) {
      var dx = xdata[i] - x_means[j];
      var dy = ydata[i] - y_means[j];
      var distance = dx * dx + dy * dy;
      if (distance < nearestDistance) {
        nearestDistance = distance;
        cdata[i] = j;
      }
    }
  }

  // Update step: move each mean to the centroid of its assigned samples.
  numSteps += 1;
  var moved = false;
  for (var c = 0; c < k; c++) {
    var n = 0;
    var totalx = 0;
    var totaly = 0;
    for (var p = 0; p < xdata.length; p++) {
      if (cdata[p] === c) {
        n += 1;
        totalx += xdata[p];
        totaly += ydata[p];
      }
    }
    var newX, newY;
    if (n === 0) {
      // Empty cluster: re-seed it at a random point in the unit square.
      newX = Math.random();
      newY = Math.random();
    } else {
      newX = totalx / n;
      newY = totaly / n;
    }
    if (newX !== x_means[c] || newY !== y_means[c]) {
      x_means[c] = newX;
      y_means[c] = newY;
      moved = true;
    }
  }
  plotMeans();

  if (numSteps >= maxSteps) {
    return true;
  }
  return numSteps >= minSteps && !moved;
}
/**
 * Show the first `i` generated samples as small circles (class "data") and
 * remove any circles beyond that count.
 *
 * @param {number} i - number of samples, taken from the start of xdata/ydata, to display.
 */
function plotData(i){
  var visible = numeric.transpose([xdata, ydata]).slice(0, i);
  var dots = svg.selectAll("circle.data").data(visible);

  dots.enter()
    .append("circle")
    .attr('class', 'data')
    .attr("r", 1);

  dots.attr("cx", function (point) { return X(point[0]); });
  dots.attr("cy", function (point) { return Y(point[1]); });

  dots.exit().remove();
}
/**
 * Draw one circle (class "means", radius 10) per current mean and animate
 * each one linearly over 200 ms to (x_means[j], y_means[j]).
 */
function plotMeans(){
  var centres = numeric.transpose([x_means, y_means]);
  var markers = svg.selectAll("circle.means").data(centres);

  markers.enter()
    .append("circle")
    .attr('class', 'means')
    .attr("r", 10);

  var slide = markers.transition().ease("linear").duration(200);
  slide.attr("cx", function (centre) { return X(centre[0]); });
  slide.attr("cy", function (centre) { return Y(centre[1]); });

  markers.exit().remove();
}
/**
 * Clamp `val` to [min, max].
 *
 * The lower bound is checked first, so when min > max a value below min
 * still yields min. A NaN input is returned unchanged.
 *
 * @param {number} val - value to clamp.
 * @param {number} min - lower bound.
 * @param {number} max - upper bound.
 * @returns {number} the clamped value.
 */
function lim(val, min, max){
  return val < min ? min
       : val > max ? max
       : val;
}
/**
 * Generate k synthetic clusters in the unit square, reveal them frame by
 * frame, then start k-means.
 *
 * For each cluster, rnd() draws a centre (amean, bmean), a shared spread
 * (astd) and mixing weights. Each of the numSamplesPerCluster samples mixes
 * two draws a and b: x is mostly a and y is mostly b. Both coordinates are
 * clamped to [0, 1] with lim() and appended to the global xdata/ydata arrays.
 *
 * A d3.timer then shows numSamplesPerFrame more samples per frame. Once every
 * sample is visible, it calls kmeans() and starts a second timer that runs
 * step() until step() returns true.
 */
function createData() {
  for (var cluster = 0; cluster < k; ++cluster) {
    var amean = rnd(0.5, 0.1);
    var bmean = rnd(0.5, 0.1);
    var astd = rnd(0.03, 0.01);   // spread shared by a and b
    // Mixing weights: ax ~ 0.98 and by ~ 0.98, so x follows a and y follows b.
    var ax = rnd(0.98, 0.01);
    var ay = rnd(0.02, 0.01);
    var bx = 1 - ax;
    var by = 1 - ay;
    for (var i = 0; i < numSamplesPerCluster; ++i) {
      var a = rnd(amean, astd);
      var b = rnd(bmean, astd);
      xdata.push(lim(ax * a + bx * b, 0, 1));
      ydata.push(lim(ay * a + by * b, 0, 1));
    }
  }

  var total = xdata.length;
  var numFrames = 0;
  d3.timer(function () {
    // Clamp to the total so the final, possibly partial, batch is still
    // drawn when total is not a multiple of numSamplesPerFrame.
    var shown = Math.min(numSamplesPerFrame * numFrames, total);
    plotData(shown);
    numFrames += 1;
    if (shown >= total) {
      kmeans();
      d3.timer(function () {
        return step();
      });
      return true; // stop the reveal timer
    }
  });
}
createData(); | |
</script> | |
</body> | |
</html> | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment