Skip to content

Instantly share code, notes, and snippets.

@mtaptich
Last active August 1, 2016 01:37
Show Gist options
  • Save mtaptich/7af7a88b73496dc991b3 to your computer and use it in GitHub Desktop.
Save mtaptich/7af7a88b73496dc991b3 to your computer and use it in GitHub Desktop.
k-means + d3.js

An example of K-means clustering in d3.js.

<!DOCTYPE html>
<meta charset="utf-8">
<style>
/* Page shell: a fixed-width column centred horizontally in the viewport. */
body {
  width: 1024px;
  /* Single shorthand: no vertical margin, automatic horizontal margins.
     (A separate margin-top placed before a `margin` shorthand is overridden by it.) */
  margin: 0 auto;
  padding: 0;
  font-family: "Lato", "PT Serif", serif;
  font-weight: 300;
  line-height: 33px;
  color: #222222;
  -webkit-font-smoothing: antialiased;
}

/* Axis strokes: thin, unfilled, pixel-aligned black lines. */
.axis path,
.axis line {
  fill: none;
  stroke: #000;
  shape-rendering: crispEdges;
}

/* Data points get a light outline. */
.dot {
  stroke: #eee;
}

/* Centroid markers: black outline, slightly translucent fill. */
.centroid {
  stroke: #000;
  fill-opacity: 0.8;
}
</style>
<body>
<script src="//d3js.org/d3.v3.min.js"></script>
<script src="model.js"></script>
<script>
// Chart geometry. The outer SVG is 960x500; `width` and `height` describe the
// inner plotting area that remains once the margins are subtracted.
var margin = {top: 100, right: 20, bottom: 80, left: 20},
    width = 960 - margin.left - margin.right,
    height = 500 - margin.top - margin.bottom,
    // Delay between animation ticks, in milliseconds.
    translate_speed = 1000;

// Categorical palette; clusters are colored by their index.
var color = d3.scale.category10();

// Root drawing group, translated so that (0, 0) is the top-left corner of the
// plotting area rather than of the SVG element.
var svg = d3.select("body").append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

// Linear scales from the [0, 50] data domain onto the plotting area.
// The y range is reversed so that larger values are drawn nearer the top.
var x = d3.scale.linear()
    .range([0, width])
    .domain([0, 50]).nice();

var y = d3.scale.linear()
    .range([height, 0])
    .domain([0,50]).nice();

// Status label, right-aligned below the plot; its text is rewritten with the
// cluster count each time a new problem is generated.
svg.append('text')
    .attr('x', width )
    .attr('y', height +50)
    .attr('class', 'status')
    .text('Clusters: ')
    .style('text-anchor', 'end')
    // A CSS length needs a unit: a bare '20' is not a valid font-size value.
    .style('font-size', '20px');

// Headline, centred above the plot; describes the current phase of the run.
svg.append('text')
    .attr('x', width/2 )
    .attr('y', -50)
    .attr('class', 'step')
    .text('Initialize clusters.')
    .style('text-anchor', 'middle')
    // Same unit requirement as above.
    .style('font-size', '36px');
/**
 * Generates a fresh random clustering problem, shows it, and returns its model.
 *
 * Picks a cluster count between 2 and 5, draws 40 points with integer
 * coordinates in [0, 49], updates the '.status' label, builds a `kmeans`
 * model from the points and plots its starting state.
 *
 * @returns {kmeans} the newly created model
 */
function initialize(){
  var SAMPLE_COUNT = 40;
  var COORD_LIMIT = 50;

  // Cluster count is drawn first, then each point's x followed by its y.
  var num_clusters = 2 + Math.floor(Math.random() * 4);

  var samples = [];
  for (var n = 0; n < SAMPLE_COUNT; n++) {
    var px = Math.floor(Math.random() * COORD_LIMIT);
    var py = Math.floor(Math.random() * COORD_LIMIT);
    samples.push([px, py]);
  }

  // Reflect the chosen cluster count in the view.
  d3.select('.status').text('Clusters: ' + num_clusters);

  // Build the model and draw the state it starts in.
  var model = new kmeans(num_clusters, samples);
  plot(model);
  return model;
}
/**
 * Redraws the scatter plot from the model's current state.
 *
 * Every call removes the groups drawn earlier and rebuilds one group holding
 * a circle per data point and a square per centroid. Points are filled by
 * their `clusterNumber`, centroids by their index in `k.centroids`.
 *
 * @param {kmeans} k - model exposing `data` and `centroids`, whose items carry `x` and `y`
 */
function plot(k){
  // Side length of the square centroid marker, in pixels.
  var centroidSize = 20;

  // Start from an empty group so nothing from the previous frame lingers.
  svg.selectAll('g').remove();
  var g = svg.append('g');

  g.selectAll(".dot")
      .data(k.data)
    .enter().append("circle")
      .attr("class", "dot")
      .attr("r", 5)
      .attr("cx", function(d) { return x(d.x); })
      .attr("cy", function(d) { return y(d.y); })
      .style("fill", function(d) { return color(d.clusterNumber); });

  // The selector matches the class assigned below.
  g.selectAll(".centroid")
      .data(k.centroids)
    .enter().append('rect')
      .attr('class', 'centroid')
      // A rect is positioned by its top-left corner: shift by half the side
      // so the square is centred on the centroid's coordinates.
      .attr("x", function(d) { return x(d.x) - centroidSize / 2; })
      .attr("y", function(d) { return y(d.y) - centroidSize / 2; })
      .attr('width', centroidSize)
      .attr('height', centroidSize)
      .attr('rx', 1)
      .attr('ry', 1)
      .style("fill", function(d, i) { return color(i); });
}
/**
 * Advances the model by one iteration and redraws it.
 *
 * Order matters: centroids are recomputed from the current assignment first,
 * then every point is reassigned, and only then is the result plotted.
 *
 * @param {kmeans} k - the model to advance
 */
function step(k){
  k.recalculate_centroids();
  k.update_clusters();
  plot(k);
}
/**
 * Runs the animation forever: create a random problem, iterate it once per
 * `translate_speed` milliseconds until no centroid moves, then start over
 * with a new problem.
 */
function run(){
  // A restart would otherwise keep showing the previous round's
  // "Assign and Update (n)." headline over the freshly initialized plot.
  d3.select('.step').text('Initialize clusters.');

  var k = initialize(),
      updates = 0;

  var go = setInterval(function(){
    if (k.isStillMoving === 1) {
      d3.select('.step').text('Assign and Update ('+updates+').');
      step(k);
      updates += 1;
    } else {
      // Converged: stop this round's timer before the next round creates its own.
      // No explicit clean-up of the old marks is done here.
      clearInterval(go);
      run();
    }
  }, translate_speed);
}
// Start the endless initialize / iterate / restart animation.
run();

// Set the height of the element returned by self.frameElement (the frame
// embedding this page, if any) to the full SVG height, margins included.
d3.select(self.frameElement).style("height", `${height + margin.top + margin.bottom}px`);
</script>
"use strict";
/**
 * One observation in the plane, optionally tagged with a cluster index.
 *
 * `clusterNumber` does not exist on the instance until `set_cluster` is
 * called, so `get_cluster` reads as `undefined` on a fresh point.
 */
class DataPoint{
  /**
   * @param {number} x - horizontal coordinate
   * @param {number} y - vertical coordinate
   */
  constructor(x, y){
    this.x = x;
    this.y = y;
  }

  // Read accessors. These are getters: use `p.get_x`, without parentheses.
  get get_x(){ return this.x; }
  get get_y(){ return this.y; }
  get get_cluster(){ return this.clusterNumber; }

  // Mutators. These are ordinary methods: `p.set_x(3)`.
  set_x(x){ this.x = x; }
  set_y(y){ this.y = y; }
  set_cluster(clusterNumber){ this.clusterNumber = clusterNumber; }
}
/**
 * The position of one cluster centre in the plane.
 */
class Centroid {
  /**
   * @param {number} x - horizontal coordinate
   * @param {number} y - vertical coordinate
   */
  constructor(x, y){
    this.x = x;
    this.y = y;
  }

  // Read accessors. These are getters: use `c.get_x`, without parentheses.
  get get_x(){ return this.x; }
  get get_y(){ return this.y; }

  // Mutators. These are ordinary methods: `c.set_x(3)`.
  set_x(x){ this.x = x; }
  set_y(y){ this.y = y; }
}
/**
 * A small two-dimensional k-means model.
 *
 * One iteration is `recalculate_centroids()` followed by `update_clusters()`.
 * `isStillMoving` is 1 while the most recent `recalculate_centroids()` moved
 * at least one centroid and 0 once none moved.
 */
class kmeans{
  /**
   * @param {number} num_cluster - number of clusters to fit
   * @param {Array<Array<number>>} samples - input points as [x, y] pairs
   */
  constructor(num_cluster, samples){
    this.num_cluster = num_cluster;
    this.samples = samples;
    this.total_data = samples.length;
    this.data = [];
    this.centroids = [];
    this.isStillMoving = 1;
    this.initialize_centroids();
    this.initialize_datapoints();
  }

  /**
   * Seeds each centroid at the coordinates of a randomly chosen sample.
   * Samples are drawn independently, so two centroids may share a seed.
   */
  initialize_centroids(){
    for (let i = 0; i < this.num_cluster; i++) {
      const pos = Math.floor(Math.random() * this.total_data);
      const seed = this.samples[pos];
      this.centroids.push(new Centroid(seed[0], seed[1]));
    }
  }

  /**
   * Wraps every sample in a DataPoint.
   *
   * The first `num_cluster` points are pre-assigned to clusters
   * 0 .. num_cluster - 1, one point per cluster. Every other point starts
   * unassigned, which is represented by a cluster number of NaN.
   */
  initialize_datapoints(){
    for (let i = 0; i < this.total_data; i++) {
      const newPoint = new DataPoint(this.samples[i][0], this.samples[i][1]);
      // Strictly less than: valid cluster indices stop at num_cluster - 1.
      if (i < this.num_cluster) {
        newPoint.set_cluster(i);
      } else {
        newPoint.set_cluster(NaN);
      }
      this.data.push(newPoint);
    }
  }

  /**
   * Euclidean distance between the point (dx, dy) and the point (cx, cy).
   *
   * @returns {number}
   */
  get_distance(dx, dy, cx, cy){
    return Math.sqrt(Math.pow((cy - dy), 2) + Math.pow((cx - dx), 2));
  }

  /**
   * Moves each centroid to the mean of the points currently assigned to it
   * and records in `isStillMoving` whether any centroid changed position.
   * A centroid with no assigned points is left where it is.
   */
  recalculate_centroids(){
    this.isStillMoving = 0;
    for (let j = 0; j < this.num_cluster; j++) {
      const centroid = this.centroids[j];
      const previousX = centroid.x;
      const previousY = centroid.y;
      let totalX = 0;
      let totalY = 0;
      let totalInCluster = 0;

      for (let k = 0; k < this.data.length; k++) {
        // Unassigned points carry NaN, which never equals a cluster index.
        if (this.data[k].get_cluster === j) {
          totalX += this.data[k].get_x;
          totalY += this.data[k].get_y;
          totalInCluster += 1;
        }
      }

      if (totalInCluster > 0) {
        centroid.set_x(totalX / totalInCluster);
        centroid.set_y(totalY / totalInCluster);
      }

      if (centroid.x !== previousX || centroid.y !== previousY) {
        this.isStillMoving = 1;
      }
    }
  }

  /**
   * Assigns every data point to its nearest centroid. On a tie the centroid
   * with the lowest index wins.
   */
  update_clusters(){
    for (let i = 0; i < this.total_data; i++) {
      const point = this.data[i];
      // Infinity, not a finite sentinel, so that points far from every
      // centroid are still compared correctly.
      let bestMinimum = Infinity;
      let currentCluster = 0;
      for (let j = 0; j < this.num_cluster; j++) {
        const distance = this.get_distance(point.get_x, point.get_y, this.centroids[j].get_x, this.centroids[j].get_y);
        if (distance < bestMinimum) {
          bestMinimum = distance;
          currentCluster = j;
        }
      }
      point.set_cluster(currentCluster);
    }
  }

  /**
   * Iterates until no centroid moves or the iteration cap is reached.
   *
   * @param {number} [max_count] - iteration cap; any falsy value means 100
   */
  fit(max_count){
    const limit = max_count || 100;
    let count = 0;
    while (this.isStillMoving === 1 && count < limit) {
      this.recalculate_centroids();
      this.update_clusters();
      count += 1;
    }
  }

  /**
   * Prints the members of each cluster to the console, with one blank line
   * after each cluster.
   */
  log(){
    for (let i = 0; i < this.num_cluster; i++) {
      console.log("Cluster ", i, " includes:");
      for (let j = 0; j < this.total_data; j++) {
        if (this.data[j].get_cluster === i) {
          console.log("(", this.data[j].get_x, ", ", this.data[j].get_y, ")");
        }
      }
      // The separator belongs after the cluster, not after every data point.
      console.log();
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment