An example of K-means clustering in d3.js.
-
-
Save mayblue9/fb21145b21025a411150c158df979221 to your computer and use it in GitHub Desktop.
k-means + d3.js
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<meta charset="utf-8">
<style>
/* Page: fixed-width column, no top/bottom margin, centred horizontally. */
body {
  width: 1024px;
  margin: 0 auto;
  font-family: "Lato", "PT Serif", serif;
  color: #222222;
  padding: 0;
  font-weight: 300;
  line-height: 33px;
  -webkit-font-smoothing: antialiased;
}
/* Axis styling; not referenced by the chart script on this page. */
.axis path,
.axis line {
  fill: none;
  stroke: #000;
  shape-rendering: crispEdges;
}
/* Data point circles; the fill colour is set per cluster by the script. */
.dot {
  stroke: #eee;
}
/* Centroid squares; the fill colour is set per cluster by the script. */
.centroid{
  stroke: #000;
  fill-opacity: 0.8;
}
</style>
<body>
<!-- d3 v3 (the script uses the v3 d3.scale.* API). An explicit https URL also
     works when the page is opened from disk, unlike a protocol-relative one. -->
<script src="https://d3js.org/d3.v3.min.js"></script>
<!-- Presumably defines the kmeans model class used by the script below. -->
<script src="model.js"></script>
<script>
// Chart geometry: a 960x500 SVG whose margins reserve room for the title
// above the plotting area and the status line below it.
const margin = { top: 100, right: 20, bottom: 80, left: 20 };
const width = 960 - margin.left - margin.right;
const height = 500 - margin.top - margin.bottom;
// Milliseconds between animation steps; also used as a transition duration.
const translate_speed = 1000;

// Categorical palette mapping a cluster index to a colour.
const color = d3.scale.category10();

// `svg` is the inner <g>, already shifted by the margins; everything is drawn into it.
const svg = d3.select("body").append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

// Both scales map data values 0..50 onto the plotting area.
const x = d3.scale.linear()
    .range([0, width])
    .domain([0, 50]).nice();
const y = d3.scale.linear()
    .range([height, 0]) // inverted: SVG y grows downward
    .domain([0, 50]).nice();

// Status line (bottom right): initialize() writes the cluster count here.
// CSS font-size needs a unit; a bare number is an invalid value and is
// ignored in standards mode, so the sizes carry "px".
svg.append('text')
    .attr('x', width)
    .attr('y', height + 50)
    .attr('class', 'status')
    .text('Clusters: ')
    .style('text-anchor', 'end')
    .style('font-size', '20px');

// Title (top centre): run() writes the current step here.
svg.append('text')
    .attr('x', width / 2)
    .attr('y', -50)
    .attr('class', 'step')
    .text('Initialize clusters.')
    .style('text-anchor', 'middle')
    .style('font-size', '36px');
/**
 * Begin a new clustering run.
 *
 * Picks a random cluster count in [2, 5] and generates 40 random integer
 * points with both coordinates in [0, 49]. It then shows the cluster count in
 * the status line, builds the model, and draws its initial state.
 * @returns {kmeans} the newly created model.
 */
function initialize() {
  const clusterCount = 2 + Math.floor(Math.random() * 4);

  // x is drawn before y for every point.
  const randomCoord = () => Math.floor(Math.random() * 50);
  const points = Array.from({ length: 40 }, () => [randomCoord(), randomCoord()]);

  d3.select('.status').text(`Clusters: ${clusterCount}`);

  const model = new kmeans(clusterCount, points);
  plot(model);
  return model;
}
/**
 * Redraw the model from scratch.
 *
 * Removes every <g> previously drawn inside the chart group, then draws one
 * circle per data point, coloured by its cluster number. It also draws one
 * square marker per centroid, coloured by its index, so marker j matches the
 * points assigned to cluster j.
 * @param {kmeans} k - model whose `data` and `centroids` are drawn.
 */
function plot(k) {
  // Side length (px) of the square centroid marker.
  const markerSize = 20;

  svg.selectAll('g').remove();
  const g = svg.append('g');

  g.selectAll('.dot')
      .data(k.data)
    .enter().append('circle')
      .attr('class', 'dot')
      .attr('r', 5)
      .attr('cx', (d) => x(d.x))
      .attr('cy', (d) => y(d.y))
      .style('fill', (d) => color(d.clusterNumber));

  // Shift by half the marker size so each square is centred on its centroid
  // (the previous fixed 2.5px offset left 20px squares visibly off-centre).
  g.selectAll('.centroid')
      .data(k.centroids)
    .enter().append('rect')
      .attr('class', 'centroid')
      .attr('x', (d) => x(d.x) - markerSize / 2)
      .attr('y', (d) => y(d.y) - markerSize / 2)
      .attr('width', markerSize)
      .attr('height', markerSize)
      .attr('rx', 1)
      .attr('ry', 1)
      .style('fill', (d, i) => color(i));
}
/**
 * Advance the model by one iteration and redraw it.
 * @param {kmeans} model - the model to advance; it is mutated in place.
 */
function step(model) {
  // Recompute centroids from the current assignment first, then reassign
  // every point against the new centroids.
  model.recalculate_centroids();
  model.update_clusters();

  plot(model);
}
/**
 * Animate one k-means run on fresh random data, then start the next one.
 *
 * Every `translate_speed` ms, one step (recompute centroids, reassign points,
 * redraw) is applied while the model reports it is still moving. When it
 * stops, or after `max_count` steps, the timer is cleared and run() starts
 * over. The cap guarantees a single run can never animate forever.
 */
function run() {
  const k = initialize();
  // Upper bound on animated steps per run.
  const max_count = 100;
  let updates = 0;
  // Reset the title; otherwise the previous run's last step count stays
  // on screen while the new run's initial state is shown.
  d3.select('.step').text('Initialize clusters.');

  const go = setInterval(function () {
    if (k.isStillMoving === 1 && updates < max_count) {
      d3.select('.step').text('Assign and Update (' + updates + ').');
      step(k);
      updates += 1;
      return;
    }

    clearInterval(go);
    // The drawn elements are <circle>s (the old 'circles' selector matched
    // nothing). NOTE: initialize() -> plot() below also removes the previous
    // groups immediately, so this delayed removal has no visible effect today.
    d3.selectAll('circle').transition().duration(translate_speed).remove();
    run();
  }, translate_speed);
}
// Start the first animated run; run() restarts itself whenever a run ends.
run();

// If this page is embedded in a frame, size that frame to fit the whole chart.
d3.select(self.frameElement).style('height', `${height + margin.top + margin.bottom}px`);
</script> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"use strict"; | |
/**
 * A single sample point: its (x, y) coordinates plus `clusterNumber`, the
 * index of the cluster it is assigned to. `clusterNumber` is absent until
 * set_cluster() is first called.
 * The `get_*` members are read-only accessor properties, not methods.
 */
class DataPoint {
  constructor(x, y) {
    Object.assign(this, { x, y });
  }

  set_x(value) { this.x = value; }

  get get_x() { return this.x; }

  set_y(value) { this.y = value; }

  get get_y() { return this.y; }

  set_cluster(index) { this.clusterNumber = index; }

  get get_cluster() { return this.clusterNumber; }
}
/**
 * A cluster centre: a mutable (x, y) position with the same set_x/set_y
 * methods and get_x/get_y accessor properties as DataPoint.
 */
class Centroid {
  constructor(x, y) {
    this.x = x;
    this.y = y;
  }

  /** @param {number} nextX - new x coordinate. */
  set_x(nextX) {
    this.x = nextX;
  }

  /** Current x coordinate (accessor property). */
  get get_x() {
    return this.x;
  }

  /** @param {number} nextY - new y coordinate. */
  set_y(nextY) {
    this.y = nextY;
  }

  /** Current y coordinate (accessor property). */
  get get_y() {
    return this.y;
  }
}
/**
 * k-means clustering of 2-D points, stepped manually or run to convergence.
 *
 * The constructor seeds one centroid per cluster on a randomly chosen sample
 * point. One iteration is recalculate_centroids(), which moves every centroid
 * to the mean of its points, followed by update_clusters(), which reassigns
 * every point to its nearest centroid.
 * `isStillMoving` is 1 while an iteration changes a centroid position or a
 * point's cluster, and 0 once an iteration changes neither.
 *
 * Uses the DataPoint and Centroid classes defined alongside it.
 */
class kmeans {
  /**
   * @param {number} num_cluster - number of clusters to find.
   * @param {Array<Array<number>>} samples - points as [x, y] pairs.
   * @throws {Error} if num_cluster > 0 but `samples` is empty.
   */
  constructor(num_cluster, samples) {
    this.num_cluster = num_cluster;
    this.samples = samples;
    this.total_data = samples.length;
    this.data = [];
    this.centroids = [];
    // seed_indices[j] is the index into `samples` that centroid j was seeded from.
    this.seed_indices = [];
    this.isStillMoving = 1;
    this.initialize_centroids();
    this.initialize_datapoints();
  }

  /**
   * Seed one centroid per cluster on a random sample point.
   *
   * Samples are visited in random order, and points whose coordinates are
   * already used are skipped, so no two centroids start on the same spot.
   * Identical centroids would leave one cluster empty for the whole run.
   * Coordinates are reused only when there are fewer distinct points than
   * clusters.
   * @throws {Error} if clusters are requested but there are no samples.
   */
  initialize_centroids() {
    if (this.num_cluster > 0 && this.total_data === 0) {
      throw new Error('kmeans: cannot seed ' + this.num_cluster + ' clusters from an empty sample set');
    }

    // Random permutation of the sample indices (Fisher-Yates shuffle).
    const order = Array.from({ length: this.total_data }, (_, i) => i);
    for (let i = 0; i < order.length - 1; i++) {
      const j = i + Math.floor(Math.random() * (order.length - i));
      [order[i], order[j]] = [order[j], order[i]];
    }

    const seeds = [];
    const usedCoords = new Set();
    for (let t = 0; t < order.length && seeds.length < this.num_cluster; t++) {
      const idx = order[t];
      const key = this.samples[idx][0] + ',' + this.samples[idx][1];
      if (usedCoords.has(key)) {
        continue;
      }
      usedCoords.add(key);
      seeds.push(idx);
    }
    // Not enough distinct points: fall back to reusing samples.
    for (let i = 0; seeds.length < this.num_cluster; i++) {
      seeds.push(order[i % order.length]);
    }

    for (const idx of seeds) {
      this.centroids.push(new Centroid(this.samples[idx][0], this.samples[idx][1]));
    }
    this.seed_indices = seeds;
  }

  /**
   * Wrap each sample in a DataPoint. A sample that seeded centroid j starts in
   * cluster j. Every other point starts unassigned (cluster NaN) until the
   * first update_clusters().
   */
  initialize_datapoints() {
    // sample index -> cluster of the first centroid seeded from that sample
    const seedCluster = new Map();
    this.seed_indices.forEach((idx, cluster) => {
      if (!seedCluster.has(idx)) {
        seedCluster.set(idx, cluster);
      }
    });

    for (let i = 0; i < this.total_data; i++) {
      const point = new DataPoint(this.samples[i][0], this.samples[i][1]);
      point.set_cluster(seedCluster.has(i) ? seedCluster.get(i) : NaN);
      this.data.push(point);
    }
  }

  /** Euclidean distance between (dx, dy) and (cx, cy). */
  get_distance(dx, dy, cx, cy) {
    return Math.sqrt(Math.pow(cy - dy, 2) + Math.pow(cx - dx, 2));
  }

  /**
   * Move each centroid to the mean of the points currently assigned to it.
   * Centroids with no assigned points stay where they are. Resets
   * `isStillMoving` to 0, then sets it to 1 if any centroid moved.
   */
  recalculate_centroids() {
    this.isStillMoving = 0;
    for (let j = 0; j < this.num_cluster; j++) {
      const centroid = this.centroids[j];
      let totalX = 0;
      let totalY = 0;
      let totalInCluster = 0;
      for (const point of this.data) {
        if (point.get_cluster === j) {
          totalX += point.get_x;
          totalY += point.get_y;
          totalInCluster += 1;
        }
      }
      if (totalInCluster === 0) {
        continue; // empty cluster: leave its centroid in place
      }
      const newX = totalX / totalInCluster;
      const newY = totalY / totalInCluster;
      // Exact comparison is intended: the same member set always yields the same mean.
      if (newX !== centroid.x || newY !== centroid.y) {
        this.isStillMoving = 1;
      }
      centroid.set_x(newX);
      centroid.set_y(newY);
    }
  }

  /**
   * Assign every point to its nearest centroid (ties go to the lower index).
   * Sets `isStillMoving` to 1 if any point changed cluster. Without this, the
   * first iteration would report "converged": the seeds' centroids do not move,
   * even though the unassigned points have only just been placed.
   */
  update_clusters() {
    for (let i = 0; i < this.total_data; i++) {
      const point = this.data[i];
      let bestDistance = Infinity;
      let bestCluster = 0;
      for (let j = 0; j < this.num_cluster; j++) {
        const centroid = this.centroids[j];
        const distance = this.get_distance(point.get_x, point.get_y, centroid.get_x, centroid.get_y);
        if (distance < bestDistance) {
          bestDistance = distance;
          bestCluster = j;
        }
      }
      if (point.get_cluster !== bestCluster) {
        this.isStillMoving = 1;
      }
      point.set_cluster(bestCluster);
    }
  }

  /**
   * Iterate until convergence or until `max_count` iterations have run.
   * @param {number} [max_count=100] - iteration cap; null/undefined means 100.
   * @returns {number} the number of iterations performed.
   */
  fit(max_count) {
    const limit = max_count == null ? 100 : max_count;
    let count = 0;
    while (this.isStillMoving === 1 && count < limit) {
      this.recalculate_centroids();
      this.update_clusters();
      count += 1;
    }
    return count;
  }

  /** Print each cluster's member points to the console, then a blank line per cluster. */
  log() {
    for (let i = 0; i < this.num_cluster; i++) {
      console.log("Cluster ", i, " includes:");
      for (const point of this.data) {
        if (point.get_cluster === i) {
          console.log("(", point.get_x, ", ", point.get_y, ")");
        }
      }
      console.log();
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment