Skip to content

Instantly share code, notes, and snippets.

@hardbyte
Created July 14, 2014 05:59
Show Gist options
  • Save hardbyte/ded34566f6fb704264b4 to your computer and use it in GitHub Desktop.
Save hardbyte/ded34566f6fb704264b4 to your computer and use it in GitHub Desktop.
K-means with D3js
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- Latest compiled and minified CSS -->
<link rel="stylesheet" href="//maxcdn.bootstrapcdn.com/bootstrap/3.2.0/css/bootstrap.min.css">
<!-- Optional theme -->
<link rel="stylesheet" href="//maxcdn.bootstrapcdn.com/bootstrap/3.2.0/css/bootstrap-theme.min.css">
<style>
/* Outlines drawn as <path> elements (the Voronoi cells): stroke only. */
path {
  stroke: #a51314;
  fill: none;
}

/* Sample points; they never intercept mouse events. */
circle.data {
  fill: steelblue;
  pointer-events: none;
}

/* Current cluster means, drawn as translucent red discs. */
circle.means {
  fill: red;
  opacity: 0.3;
}

/* Axis lines. The width must be an unquoted length: a quoted string is not a
   valid value for stroke-width, so the declaration would be ignored. */
line.ax {
  stroke: grey;
  stroke-width: 2px;
}
</style>
</head>
<body>
<div class="container-fluid">
<div class="row">
<div id="vis"></div>
</div>
</div>
<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.11.1/jquery.min.js"></script>
<script src="//maxcdn.bootstrapcdn.com/bootstrap/3.2.0/js/bootstrap.min.js"></script>
<script src="http://d3js.org/d3.v3.min.js"></script>
<script src="http://numericjs.com/lib/numeric-1.2.6.min.js"></script>
<script>
// Plot geometry: the outer SVG is 960x500 and the drawing area is whatever is
// left once the margins are subtracted.
var margin = {top: 20, right: 20, bottom: 30, left: 50};
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Scales from data coordinates to pixels. No domain is configured here, so
// d3's default domain applies (NOTE(review): [0, 1] per the d3 v3 docs).
// Y's range is flipped so that larger values are drawn higher up.
var X = d3.scale.linear().range([0, width]);
var Y = d3.scale.linear().range([height, 0]);

// Voronoi layout clipped to the drawing area; step() uses it to outline the
// cell belonging to each mean.
var voronoi = d3.geom.voronoi().clipExtent([[0, 0], [width, height]]);

// Create the <svg> inside #vis. `svg` ends up referring to the inner <g>,
// which is shifted by the left/top margins.
var svg = d3.select("#vis")
    .append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
    .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

// Demo parameters.
var k = 4;                      // number of clusters / means
var numSamplesPerFrame = 10;    // samples revealed per animation frame
var numSamplesPerCluster = 200; // samples generated for each cluster

// Declared but not referenced anywhere else in the visible script.
var data;

// Sample coordinates, plus the index of the mean each sample is assigned to
// (cdata is filled in by step()).
var xdata = [];
var ydata = [];
var cdata = [];

// Current position of each of the k means.
var x_means = [];
var y_means = [];
// Draw the two axis lines through the centre of the plot. Each datum is
// [start point, end point, key] in data coordinates; the third entry is used
// as the join key.
var components = svg.selectAll("line.ax").data(
    [
        [[0.0, 0.5], [1, 0.5], "X"],
        [[0.5, 0.0], [0.5, 1], "Y"]
    ],
    function (segment) { return segment[2]; }
);

components.exit().remove();

// Map both end points of every new segment through the X/Y scales.
components.enter()
    .append("line")
    .attr("class", "ax")
    .attr("x1", function (segment) { return X(segment[0][0]); })
    .attr("y1", function (segment) { return Y(segment[0][1]); })
    .attr("x2", function (segment) { return X(segment[1][0]); })
    .attr("y2", function (segment) { return Y(segment[1][1]); });
/**
 * Draw a roughly bell-shaped random number centred on `mean`.
 *
 * Ten uniform samples from [-1, 1) are added together, then the sum is scaled
 * by `std` and shifted by `mean`. The sum is not normalised: on its own it has
 * a standard deviation of sqrt(10 / 3) (about 1.83), so the spread of the
 * result is about 1.83 * std rather than exactly `std`.
 *
 * @param {number} mean - centre of the distribution.
 * @param {number} std - scale factor applied to the summed noise.
 * @returns {number} the random sample.
 */
function rnd(mean, std) {
    var noise = 0;
    var remaining = 10;
    while (remaining > 0) {
        noise += Math.random() * 2 - 1;
        remaining -= 1;
    }
    return noise * std + mean;
}
/**
 * Initialise k-means: place each of the `k` means at a uniformly random
 * position in the unit square (overwriting x_means / y_means) and draw them.
 */
function kmeans() {
    // Step 1, choose k random starting positions.
    // The counter is declared locally so it does not leak into (or clobber)
    // a global `i`, and so the function also works in strict mode.
    for (var i = 0; i < k; i++) {
        x_means[i] = Math.random();
        y_means[i] = Math.random();
    }
    plotMeans();
}
// Number of k-means iterations performed so far (shared across step() calls).
var numSteps = 0;

/**
 * Perform one k-means iteration and update the drawing.
 *
 * 1. Redraw the Voronoi cells of the current means.
 * 2. Assign every sample to its nearest mean (squared Euclidean distance),
 *    storing the index in cdata.
 * 3. Move each mean to the centroid of its samples; a mean that attracted no
 *    samples is re-seeded at a random position.
 *
 * @returns {boolean} true when iteration should stop, i.e. no mean moved
 *     during this step or the step limit was reached. The value is handed to
 *     d3.timer, which stops calling back once it receives true.
 */
function step() {
    // Safety cap so the timer always terminates, even if means keep moving
    // (e.g. an empty cluster being re-seeded on every step).
    var maxSteps = 100;
    var i, j;

    // SVG path description for one Voronoi cell (an array of [x, y] points).
    function polygon(cell) {
        return "M" + cell.join("L") + "Z";
    }

    // Rebuild the Voronoi outlines from the current means, in pixel space.
    function redraw() {
        var sites = [];
        for (var m = 0; m < k; m++) {
            sites.push([X(x_means[m]), Y(y_means[m])]);
        }
        // The path string doubles as the join key.
        var cells = svg.selectAll("path").data(voronoi(sites), polygon);
        cells.exit().remove();
        cells.enter().append("path");
        cells.attr("d", polygon).order();
    }

    redraw();

    // For each point calculate the nearest mean.
    // TODO partitioning the observations according to the Voronoi diagram generated by the means
    for (i = 0; i < xdata.length; i++) {
        var nearestDistance = Infinity;
        for (j = 0; j < k; j++) {
            var dx = xdata[i] - x_means[j];
            var dy = ydata[i] - y_means[j];
            var distance = dx * dx + dy * dy;
            if (distance < nearestDistance) {
                nearestDistance = distance;
                cdata[i] = j;
            }
        }
    }

    // Accumulate per-cluster sums in a single pass over the samples.
    var counts = [];
    var sumX = [];
    var sumY = [];
    for (j = 0; j < k; j++) {
        counts[j] = 0;
        sumX[j] = 0;
        sumY[j] = 0;
    }
    for (i = 0; i < xdata.length; i++) {
        var cluster = cdata[i];
        // Skip samples without a valid assignment.
        if (counts[cluster] !== undefined) {
            counts[cluster] += 1;
            sumX[cluster] += xdata[i];
            sumY[cluster] += ydata[i];
        }
    }

    // Move every mean to the centroid of its samples and note whether
    // anything actually moved.
    var changed = false;
    for (j = 0; j < k; j++) {
        var newX, newY;
        if (counts[j] === 0) {
            // Not part of any cluster: re-seed at a random position.
            newX = Math.random();
            newY = Math.random();
        } else {
            newX = sumX[j] / counts[j];
            newY = sumY[j] / counts[j];
        }
        // Exact comparison is intentional: with unchanged assignments the
        // same sums are recomputed in the same order, giving identical values.
        if (newX !== x_means[j] || newY !== y_means[j]) {
            x_means[j] = newX;
            y_means[j] = newY;
            changed = true;
        }
    }

    numSteps += 1;
    plotMeans();

    // Stop as soon as the means have converged, or when the cap is hit.
    return !changed || numSteps >= maxSteps;
}
/**
 * Draw the first `i` samples as small circles.
 *
 * Pairs up xdata/ydata into [x, y] points, keeps the first `i` of them, and
 * joins them to the `circle.data` elements: new circles are created with
 * radius 1, all circles are positioned through the X/Y scales, and surplus
 * circles are removed.
 *
 * @param {number} i - how many samples (from the start of the data) to show.
 */
function plotData(i) {
    var visiblePoints = numeric.transpose([xdata, ydata]).slice(0, i);
    var toPixelX = function (point) { return X(point[0]); };
    var toPixelY = function (point) { return Y(point[1]); };

    var dots = svg.selectAll("circle.data").data(visiblePoints);

    // Create missing circles first so the positioning below reaches them too.
    dots.enter()
        .append("circle")
        .attr("class", "data")
        .attr("r", 1);

    dots.attr("cx", toPixelX).attr("cy", toPixelY);

    dots.exit().remove();
}
/**
 * Draw the current means as large circles and animate them to their
 * positions.
 *
 * Pairs up x_means/y_means into [x, y] points and joins them to the
 * `circle.means` elements. New circles are created with radius 10 and no
 * initial position; every circle is then moved to its scaled position by a
 * 200 ms linear transition. Surplus circles are removed.
 */
function plotMeans() {
    var centres = numeric.transpose([x_means, y_means]);
    var toPixelX = function (centre) { return X(centre[0]); };
    var toPixelY = function (centre) { return Y(centre[1]); };

    var markers = svg.selectAll("circle.means").data(centres);

    // Create missing circles first so the transition below includes them.
    markers.enter()
        .append("circle")
        .attr("class", "means")
        .attr("r", 10);

    markers.transition()
        .ease("linear")
        .duration(200)
        .attr("cx", toPixelX)
        .attr("cy", toPixelY);

    markers.exit().remove();
}
/**
 * Clamp `val` to the range [min, max].
 *
 * The lower bound is checked first, so if the bounds are ever inverted a
 * value below `min` still yields `min`.
 *
 * @param {number} val - value to clamp.
 * @param {number} min - lower bound.
 * @param {number} max - upper bound.
 * @returns {number} `min` if val < min, `max` if val > max, otherwise `val`.
 */
function lim(val, min, max) {
    return val < min ? min : (val > max ? max : val);
}
/**
 * Generate the sample data and start the animation.
 *
 * For each of the `k` clusters a random centre and spread are chosen, then
 * `numSamplesPerCluster` points are drawn around it, mixed through a nearly
 * diagonal 2x2 weighting, clamped to [0, 1] and appended to xdata / ydata.
 *
 * A d3 timer then reveals the samples `numSamplesPerFrame` at a time. Once
 * every sample has been drawn, the means are initialised with kmeans() and a
 * second timer runs step() until it reports that it is done.
 */
function createData() {
    for (var cluster = 0; cluster < k; ++cluster) {
        // Cluster centre in the (a, b) source coordinates. The same spread
        // is used for both coordinates.
        var amean = rnd(0.5, 0.1);
        var bmean = rnd(0.5, 0.1);
        var astd = rnd(0.03, 0.01);

        // Mixing weights: x is mostly a with a little b, y mostly b with a
        // little a.
        var ax = rnd(0.98, 0.01);
        var ay = rnd(0.02, 0.01);
        var bx = 1 - ax;
        var by = 1 - ay;

        for (var i = 0; i < numSamplesPerCluster; ++i) {
            var a = rnd(amean, astd),
                b = rnd(bmean, astd);
            var x = lim(ax * a + bx * b, 0, 1),
                y = lim(ay * a + by * b, 0, 1);
            xdata.push(x);
            ydata.push(y);
        }
    }

    var total = xdata.length;
    var numFrames = 0;
    d3.timer(function () {
        // Clamp so the final frame shows the remaining partial batch when
        // the sample count is not a multiple of numSamplesPerFrame.
        var shown = Math.min(numSamplesPerFrame * numFrames, total);
        plotData(shown);
        ++numFrames;

        if (shown < total) {
            return; // keep the reveal timer running
        }

        // Everything is on screen: start clustering and stop this timer.
        kmeans();
        d3.timer(function () {
            return step();
        });
        return true;
    });
}
// Entry point: generate the samples and start the reveal / clustering timers.
createData();
</script>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment