Skip to content

Instantly share code, notes, and snippets.

@gangtao
Last active September 8, 2017 06:47
Show Gist options
Save gangtao/02e4b9b5b592cbbbb55c3dd7f6d744ed to your computer and use it in GitHub Desktop.
ML Explained KNN
<!-- Two columns: the chart on the left, the controls on the right. -->
<div class="container">
  <div class="row">
    <div class="col-sm-6">
      <!-- The script appends the <svg> canvas into this element. -->
      <div id="chart"></div>
    </div>
    <div class="col-sm-6">
      <!-- Data-entry and prediction controls. -->
      <div class="row">
        <button type="button" class="btn btn-xs" id="new_category_button">New Category Data</button>
        <!-- Filled by the script with the name/colour of the category being entered. -->
        <span id="category_info"></span>
        <button type="button" class="btn btn-xs" id="classify_button">Predict</button>
        <button type="button" class="btn btn-xs" id="clean_button">Clear</button>
      </div>
      <!-- K: how many nearest neighbours vote on each prediction. -->
      <div class="row">
        <label for="k_input">K</label>
        <input type="text" id="k_input" placeholder="3">
      </div>
    </div>
  </div>
</div>
// ---- Chart geometry ----------------------------------------------------
var size = 400; // width and height of the square SVG canvas, in pixels
var margin_size = 50; // gap on every side between the canvas edge and the plot area
var point_size = 8; // radius of a plotted data point, in pixels
var domain_max = 100; // both axes span the data range [0, domain_max]

// ---- Colours -----------------------------------------------------------
// Maps a category index to a colour from d3's schemeCategory10 palette.
var colors = d3.scaleOrdinal(d3.schemeCategory10);

// ---- Mutable demo state ------------------------------------------------
var data = []; // every point entered so far
var current_category = undefined; // category currently being entered; undefined until "New Category"
var K = 3; // number of neighbours consulted per prediction
var total_category = 0; // how many categories have been created
/**
 * Appends an SVG <circle> centred on `p` to `container` and tags it with the
 * "circle" CSS class.
 *
 * @param {Object} container - d3 selection to append into.
 * @param {{x: number, y: number}} p - centre of the circle.
 * @param {number} r - radius.
 * @param {string} [color] - optional fill colour; when falsy, the stylesheet's fill applies.
 * @returns {Object} the selection for the new circle.
 */
function drawCircle(container, p, r, color) {
  var shape = container.append("circle");
  var geometry = { cx: p.x, cy: p.y, r: r };
  Object.keys(geometry).forEach(function(name) {
    shape.attr(name, geometry[name]);
  });
  shape.classed("circle", true);
  if (color) {
    shape.style("fill", color);
  }
  return shape;
}
/**
 * Appends an SVG <line> from `p1` to `p2` to `container` and tags it with
 * the "line" CSS class.
 *
 * @param {Object} container - d3 selection to append into.
 * @param {{x: number, y: number}} p1 - start point.
 * @param {{x: number, y: number}} p2 - end point.
 * @returns {Object} the selection for the new line.
 */
function drawLine(container, p1, p2) {
  var segment = container.append("line");
  var endpoints = [
    ["x1", p1.x],
    ["y1", p1.y],
    ["x2", p2.x],
    ["y2", p2.y]
  ];
  endpoints.forEach(function(pair) {
    segment.attr(pair[0], pair[1]);
  });
  return segment.classed("line", true);
}
/**
 * Starts a new category and shows its label, in its colour, in
 * #category_info.
 *
 * Categories are numbered 0, 1, 2, ...; `total_category` counts how many
 * have been created.
 */
function gen_new_category() {
  // `== null` matches undefined (the "no category yet" state) exactly as `== undefined` does.
  current_category = current_category == null ? 0 : current_category + 1;
  total_category += 1;
  var label = "category" + current_category;
  d3.select("#category_info")
    .text(label)
    .style("color", colors(current_category));
}
/**
 * Resets the demo.
 *
 * Forgets every data point and category, clears the category label and
 * removes everything drawn on the chart: points, neighbour lines,
 * neighbourhood disc and prediction marker.
 */
function clean() {
  data = [];
  current_category = undefined;
  total_category = 0;
  d3.select("#category_info").text("");
  // Data points and the prediction marker carry the "circle" class.
  $(".circle").remove();
  // Lines drawn from the cursor to its neighbours.
  $(".line").remove();
  // The neighbourhood disc swaps its "circle" class for "range", so the
  // ".circle" selector above misses it. Remove it explicitly, or it stays
  // visible after Clear.
  $(".range").remove();
}
/**
 * Euclidean distance between two points.
 *
 * @param {{x: number, y: number}} p1
 * @param {{x: number, y: number}} p2
 * @returns {number} the straight-line distance from p1 to p2.
 */
function get_distance(p1, p2) {
  var dx = p2.x - p1.x;
  var dy = p2.y - p1.y;
  return Math.sqrt(Math.pow(dx, 2) + Math.pow(dy, 2));
}
/**
 * Finds the `k` points nearest to `current`, nearest first.
 *
 * @param {{x: number, y: number}} current - query position, in the same
 *   coordinate space as each point's `coordinates`.
 * @param {Array<{coordinates: {x: number, y: number}}>} points - candidates;
 *   the array is not modified.
 * @param {number|string} k - how many neighbours to return. Array#slice
 *   converts it to an integer, so a numeric string also works.
 * @returns {Array<{p: Object, d: number}>} up to `k` records, each holding
 *   the original point (`p`) and its distance (`d`), in ascending distance.
 *   Ties keep input order because the sort is stable.
 */
function get_knn(current, points, k) {
  // Build the records with map's return value rather than pushing from inside map.
  var dists = points.map(function(item) {
    return { p: item, d: get_distance(current, item.coordinates) };
  });
  dists.sort(function(a, b) {
    return a.d - b.d;
  });
  return dists.slice(0, k);
}
/**
 * Majority vote over the categories of the given neighbours.
 *
 * Votes are counted from `knn` alone, so the result does not depend on any
 * global category counter being in sync with the data. Categories are
 * assumed to be non-negative integer indices, as assigned by
 * gen_new_category.
 *
 * @param {Array<{p: {category: number}}>} knn - neighbours as returned by get_knn.
 * @returns {number|undefined} the category with the most votes. On a tie the
 *   smallest category index wins. Returns undefined when `knn` is empty.
 */
function get_vote(knn) {
  var counts = [];
  knn.forEach(function(item) {
    var label = item.p.category;
    counts[label] = (counts[label] || 0) + 1;
  });
  var winner;
  // forEach visits indices in ascending order and skips categories with no
  // votes. The strict ">" keeps the smaller index when counts are tied.
  counts.forEach(function(count, label) {
    if (winner === undefined || count > counts[winner]) {
      winner = label;
    }
  });
  return winner;
}
// Page bootstrap: once the DOM is ready, build the chart and wire up the controls.
$(function() {
  var margin = {
    top: margin_size,
    right: margin_size,
    bottom: margin_size,
    left: margin_size
  };
  var width = size - margin.left - margin.right;
  var height = size - margin.top - margin.bottom;

  var root = d3
    .select("#chart")
    .append("svg")
    .attr("width", size)
    .attr("height", size);
  // `g` holds the axes and the data points, offset by the margins.
  var g = root
    .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
  // `layer1` holds the prediction overlay. It is attached to `root`, not `g`,
  // so it uses raw SVG coordinates: the same space as each point's
  // `coordinates` field.
  var layer1 = root.append("g").attr("id", "layer1");

  // Draw the axes. Both map the data domain [0, domain_max] onto the plot area.
  var xScale = d3.scaleLinear().rangeRound([0, width]);
  var yScale = d3.scaleLinear().rangeRound([height, 0]);
  xScale.domain([0, domain_max]);
  yScale.domain([0, domain_max]);
  g
    .append("g")
    .attr("transform", "translate(0," + height + ")")
    .call(d3.axisBottom(xScale));
  g.append("g").call(d3.axisLeft(yScale));

  // "New Category Data": start a new category, then (re)bind the click
  // handler that adds a point of the current category. selection.on()
  // replaces an existing listener of the same type, so handlers do not
  // accumulate.
  $("#new_category_button").click(function() {
    gen_new_category();
    root.on("click", function() {
      var coords = d3.mouse(this);
      // Shift from SVG coordinates into plot-area coordinates.
      var mapped_coords = [coords[0] - margin_size, coords[1] - margin_size];
      // Convert pixels to the data domain, rounded to whole units.
      var newData = {
        x: Math.round(xScale.invert(mapped_coords[0])),
        y: Math.round(yScale.invert(mapped_coords[1]))
      };
      // Raw SVG position; neighbour distances and lines are computed in this space.
      newData.coordinates = { x: coords[0], y: coords[1] };
      // Clicks are ignored while no category exists (e.g. right after Clear).
      if (current_category !== undefined) {
        drawCircle(
          g,
          { x: mapped_coords[0], y: mapped_coords[1] },
          point_size,
          colors(current_category)
        );
        newData.category = current_category;
        data.push(newData);
      }
    });
  });

  $("#clean_button").click(function() {
    clean();
  });

  // "Predict": on every mouse move over the chart, redraw the K nearest
  // neighbours, the disc that encloses them, and the predicted category.
  $("#classify_button").click(function() {
    root.on("mousemove", function() {
      $("#layer1").empty();
      var coords = d3.mouse(this);
      var mapped_coords = [coords[0] - margin_size, coords[1] - margin_size];
      var current_data_point = {
        x: Math.round(xScale.invert(mapped_coords[0])),
        y: Math.round(yScale.invert(mapped_coords[1]))
      };
      var current_point = { x: coords[0], y: coords[1] };
      // Skip the margins: only predict inside the plotted domain.
      if (
        current_data_point.x < 0 ||
        current_data_point.x > domain_max ||
        current_data_point.y < 0 ||
        current_data_point.y > domain_max
      ) {
        return;
      }
      var knn = get_knn(current_point, data, K);
      // No points yet (or everything was cleared): nothing to predict.
      // Without this guard, knn[knn.length - 1] is undefined and reading `.d` throws.
      if (knn.length === 0) {
        return;
      }
      var d_max = knn[knn.length - 1].d;
      var vote = get_vote(knn);
      knn.forEach(function(item) {
        drawLine(layer1, item.p.coordinates, current_point);
      });
      // Disc reaching just past the farthest neighbour. Swap its "circle"
      // class for "range" so it gets the translucent .range styling.
      var range_circle = drawCircle(layer1, current_point, d_max + 5);
      range_circle.classed("range", true);
      range_circle.classed("circle", false);
      // Marker at the cursor, filled with the predicted category's colour.
      drawCircle(layer1, current_point, point_size, colors(vote));
    });
  });

  // K arrives from the text box as a string. Parse it as a base-10 integer
  // and accept only positive values. NaN > 0 is false, so empty or
  // non-numeric input is rejected too and the previous K is kept; otherwise
  // get_knn would return no neighbours.
  $("#k_input").change(function() {
    var parsed = parseInt($(this).val(), 10);
    if (parsed > 0) {
      K = parsed;
    }
  });
});
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/4.9.1/d3.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
/* Page background; matches the chart background below. */
body {
  margin-top: 10px;
  background-color: #555;
}
/* Keeps inputs (the visible one is the K field) narrow. */
input {
  width: 30px;
}
#chart {
  background-color: #555;
}
.row {
  padding-top: 2px;
}
/* Default class applied by drawCircle: data points and the prediction marker. */
.circle {
  fill: #fa6900;
  fill-opacity: 1;
  stroke: #000;
  stroke-width: 1px;
}
/* Translucent disc around the neighbourhood examined during prediction. */
.range {
  fill: #ccc;
  fill-opacity: 0.2;
  stroke: #000;
  stroke-width: 0;
}
/* Lines from the cursor to each nearest neighbour. */
.line {
  stroke: #ccc;
  stroke-width: 3px;
}
<link href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/css/bootstrap.min.css" rel="stylesheet" />
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment