Skip to content

Instantly share code, notes, and snippets.

@mik30s
Last active June 8, 2018 02:13
Show Gist options
  • Save mik30s/dea6a0762d335ebfc05021d13b9d9daf to your computer and use it in GitHub Desktop.
Save mik30s/dea6a0762d335ebfc05021d13b9d9daf to your computer and use it in GitHub Desktop.
K Nearest Neighbors in Javascript
<div id="knn-react-app" class="embedded-app" style="text-align: center;">
<button id="knn-1-predict-btn"> <i class=""></i> Predict</button>
<input type="checkbox" id="knn-1-scale-btn" />
<label for="knn-1-scale-btn">Scale?</label>
<input id="knn-1-kvalue" placeholder="# of neighbours eg. 3"/>
<input id="knn-1-point-input" placeholder="weight,color,seeds eg. 371,3,1" />
<i class="fas fa-equals"></i>
<img width="25" height="25" id="knn-1-class-img" />
</div>
<div id="knn-vis1" style="width: 100%; height: 520px;"></div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.17/d3.min.js"></script>
<script type="text/javascript" async>
/**
 * Brute-force k-nearest-neighbours classifier.
 *
 * Every training row is a feature vector followed by its class label in the
 * last position, e.g. `[weight, color, seeds, "banana"]`.
 */
class KNearestNeighbours {
  /**
   * @param {Array<Array>} d - training rows; the last element of each row is its label.
   *   The outer array is copied, the rows themselves are shared with the caller.
   * @param {number} k - number of neighbours that vote; an even k is bumped to k + 1
   *   so that two-class votes cannot tie.
   * @param {Function} metric - distance function, called as `metric(point, features, ...args)`.
   * @param {Array} [args=[]] - extra arguments forwarded to `metric` (e.g. `[2]` for Minkowski p).
   */
  constructor(d, k, metric, args = []) {
    this.k = (k % 2 === 0) ? k + 1 : k;
    this.metric = metric;
    this.args = args;
    this.data = d.slice();
  }

  /**
   * Classifies each example by majority vote among its k nearest training rows.
   *
   * @param {Array<Array<number>>} examples - feature vectors (no label).
   * @returns {Array<Array>} one `[label, neighbours, point]` triple per example, where
   *   `label` is the winning class (null when there are no neighbours), `neighbours`
   *   are the k closest training rows (nearest first) and `point` is the example itself.
   */
  predict(examples) {
    const classes = [];
    for (const point of examples) {
      // Compute every distance once and compare at full precision: rounding
      // distances (e.g. with Math.ceil) makes all sub-unit distances tie, which
      // breaks min-max-scaled features living in [0, 1]. Sorting a decorated
      // copy also leaves this.data untouched.
      const ranked = this.data
        .map((row) => ({
          row,
          dist: this.metric(point, row.slice(0, row.length - 1), ...this.args),
        }))
        .sort((a, b) => a.dist - b.dist);
      const kclosests = ranked.slice(0, this.k).map((entry) => entry.row);

      // Majority vote; on a tie, the label that reached the top count first wins.
      const counts = new Map();
      let maxcnt = 0;
      let maxelem = null;
      for (const neighbour of kclosests) {
        const label = neighbour[neighbour.length - 1];
        const count = (counts.get(label) || 0) + 1;
        counts.set(label, count);
        if (count > maxcnt) {
          maxcnt = count;
          maxelem = label;
        }
      }
      classes.push([maxelem, kclosests, point]);
    }
    return classes;
  }

  /**
   * Resubstitution error rate: classifies every training row against the
   * training set (which still contains that row) and returns the fraction
   * whose predicted label differs from its stored label. This is an
   * optimistic estimate, not a held-out score.
   *
   * @returns {number} misses / number of training rows (NaN when there is no data).
   */
  score() {
    // Keyed by the stringified feature vector; a duplicate vector keeps the last label seen.
    const lookup = {};
    for (const row of this.data) {
      lookup[row.slice(0, row.length - 1)] = row[row.length - 1];
    }
    const examples = this.data.map((row) => row.slice(0, row.length - 1));
    let misses = 0;
    for (const [label, , point] of this.predict(examples)) {
      if (lookup[point] !== label) {
        misses += 1;
      }
    }
    return misses / this.data.length;
  }
}
/**
 * Randomly partitions labelled rows into disjoint training and test sets.
 *
 * Each row is a feature vector followed by its class label in the last position.
 *
 * @param {Array<Array>} data - labelled rows; not modified.
 * @param {number} trainPerc - fraction (0..1) of rows for the training set; it
 *   receives floor(data.length * trainPerc) rows (clamped to [0, data.length]).
 * @param {() => number} [random=Math.random] - uniform source in [0, 1); inject a
 *   seeded generator for reproducible splits.
 * @returns {{classTrainSet: Array, classTestSet: Array,
 *            pointTrainSet: Array<Array>, pointTestSet: Array<Array>}}
 *   the labels and the label-stripped feature vectors of each partition, index-aligned.
 */
function trainTestSplit(data, trainPerc, random = Math.random) {
  const n = Math.min(data.length, Math.max(0, Math.floor(data.length * trainPerc)));

  // Fisher–Yates shuffle of the row indices: a uniform random permutation in
  // O(n). Taking a prefix for training and the rest for testing guarantees the
  // two sets are disjoint and together cover every row exactly once.
  const order = Array.from({ length: data.length }, (_, i) => i);
  for (let i = order.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [order[i], order[j]] = [order[j], order[i]];
  }

  const trainSet = order.slice(0, n).map((i) => data[i]);
  const testSet = order.slice(n).map((i) => data[i]);
  const labelOf = (row) => row[row.length - 1];
  const featuresOf = (row) => row.slice(0, row.length - 1);
  return {
    classTrainSet: trainSet.map(labelOf),
    classTestSet: testSet.map(labelOf),
    pointTrainSet: trainSet.map(featuresOf),
    pointTestSet: testSet.map(featuresOf),
  };
}
/**
 * Min-max normalises a column of numbers into [0, 1] when `scale` is exactly `true`.
 *
 * @param {boolean} scale - only the literal `true` enables scaling.
 * @param {number[]} data - values to normalise; never modified.
 * @returns {number[]} a new scaled array, or `data` itself when scaling is off.
 *   A constant column falls back to a range of 1, so every value maps to 0.
 */
function minMaxScaler(scale, data) {
  const hi = Math.max(...data);
  const lo = Math.min(...data);
  const span = (hi !== lo) ? hi - lo : 1;
  if (scale !== true) {
    return data;
  }
  return data.map((value) => (value - lo) / span);
}
/**
 * Minkowski distance of order p between two numeric vectors:
 *   (sum_i |a_i - b_i|^p)^(1/p)
 * p = 1 gives Manhattan distance, p = 2 Euclidean, p = Infinity Chebyshev.
 * The 1/p root is applied once to the whole sum, not to each term.
 *
 * @param {number[]} a - first vector; its length drives the comparison.
 * @param {number[]} b - second vector, compared element-wise against `a`.
 * @param {number} p - order of the distance (p >= 1 for a true metric).
 * @returns {number} the distance; 0 for empty vectors.
 */
function minkowski(a, b, p) {
  if (p === Infinity) {
    // Limit of the p-norm as p grows: the largest coordinate difference.
    let largest = 0;
    for (let i = 0; i < a.length; i++) {
      largest = Math.max(largest, Math.abs(a[i] - b[i]));
    }
    return largest;
  }
  let total = 0;
  for (let i = 0; i < a.length; i++) {
    total += Math.pow(Math.abs(a[i] - b[i]), p);
  }
  return Math.pow(total, 1 / p);
}
(function () {
  // Style shared by every query-point → neighbour connector.
  const NEIGHBOUR_LINE = { color: 'rgb(55, 128, 191)', width: 3 };

  /**
   * Renders the rows as a Plotly 3-D scatter into #knn-vis1.
   *
   * @param {Array<Array>} data - rows `[weight, color, seeds, type]`; `type` is
   *   "banana", "apple" or "unknown" (the user's query point).
   * @param {Array<Array<Array>>} links - `[fromRow, toRow]` pairs to join with a line.
   *   Rows that are elements of `data` are drawn at their (possibly scaled) plot
   *   coordinates; any other row is drawn at its raw coordinates.
   * @param {boolean} scale - min-max scale every axis before plotting.
   */
  function drawScatterPlot(data, links, scale) {
    // Scale each feature column over ALL rows at once so every class shares the
    // same axes; scaling class by class would map each class onto [0, 1]
    // independently and make the classes overlap meaninglessly.
    const xs = minMaxScaler(scale, data.map((row) => row[0]));
    const ys = minMaxScaler(scale, data.map((row) => row[1]));
    const zs = minMaxScaler(scale, data.map((row) => row[2]));
    const coords = new Map(data.map((row, i) => [row, [xs[i], ys[i], zs[i]]]));
    const coordsOf = (row) => coords.get(row) || row;

    const classTrace = (type, name, marker) => {
      const rows = data.filter((row) => row[3] === type);
      return {
        x: rows.map((row) => coordsOf(row)[0]),
        y: rows.map((row) => coordsOf(row)[1]),
        z: rows.map((row) => coordsOf(row)[2]),
        mode: 'markers',
        type: 'scatter3d',
        name,
        marker,
      };
    };

    // layout.shapes only apply to 2-D cartesian axes, so inside a scatter3d
    // scene each connector has to be its own 'lines' trace.
    const lineTraces = links.map(([from, to]) => {
      const [x0, y0, z0] = coordsOf(from);
      const [x1, y1, z1] = coordsOf(to);
      return {
        x: [x0, x1],
        y: [y0, y1],
        z: [z0, z1],
        mode: 'lines',
        type: 'scatter3d',
        showlegend: false,
        hoverinfo: 'none',
        line: NEIGHBOUR_LINE,
      };
    });

    const traces = [
      classTrace('banana', 'Banana', {
        size: 15, color: 'orange', symbol: 'circle', opacity: 0.8, line: { width: 0.5 },
      }),
      classTrace('apple', 'Apple', {
        size: 12, color: '#ee0000', symbol: 'circle', opacity: 0.8, line: { width: 0.5 },
      }),
      classTrace('unknown', 'Unknown', {
        size: 15, color: 'green', symbol: 'square', opacity: 1, line: { width: 0.5 },
      }),
    ].concat(lineTraces);

    const layout = {
      // Axis settings for a 3-D plot live under `scene`, not at the top level.
      scene: {
        xaxis: { title: 'Weight (g)' },
        yaxis: { title: 'Color' },
        zaxis: { title: '# seeds' },
      },
      autosize: false,
      width: 600,
      height: 500,
      legend: { orientation: 'h' },
      margin: { l: 0, r: 25, b: 0, t: 0, pad: 0 },
    };
    Plotly.newPlot('knn-vis1', traces, layout, { displayModeBar: true });
  }

  d3.csv('/assets/data/banana_data.csv', function (rows) {
    // A one-argument d3 v3 callback receives null when loading/parsing fails.
    if (!rows) {
      console.error('Could not load /assets/data/banana_data.csv');
      return;
    }
    let data = rows.map((r) => [
      parseFloat(r['weight']),
      parseFloat(r['color']),
      parseFloat(r['seeds']),
      r['type'],
    ]);
    let scale = false;
    let links = [];
    drawScatterPlot(data, links, scale);

    document.getElementById('knn-1-scale-btn').addEventListener('change', (event) => {
      scale = event.target.checked;
      // Redraw with the current neighbour links so toggling does not drop them.
      drawScatterPlot(data, links, scale);
    });

    document.getElementById('knn-1-predict-btn').addEventListener('click', () => {
      // The model has three features; extra comma-separated values are ignored.
      const input = document.getElementById('knn-1-point-input').value
        .split(',')
        .slice(0, 3)
        .map((s) => parseFloat(s));
      const kValue = Number.parseInt(document.getElementById('knn-1-kvalue').value, 10);
      // parseFloat/parseInt report bad input as NaN (never undefined/null).
      if (input.length < 3 || input.some((v) => Number.isNaN(v)) || !(kValue >= 1)) {
        alert('Incorrect inputs');
        return;
      }

      // Drop the previous query point so it is not used as training data.
      if (data.length > 0 && data[data.length - 1][3] === 'unknown') {
        data = data.slice(0, data.length - 1);
      }

      // The "Scale?" checkbox affects only the plot; the model uses raw features.
      const knn = new KNearestNeighbours(data, kValue, minkowski, [2]);
      const [label, neighbours] = knn.predict([input])[0];
      console.log('Label was', label);
      document.getElementById('knn-1-class-img').src =
        label === 'banana' ? '/assets/img/banana.svg' : '/assets/img/apple.svg';

      const query = input.concat(['unknown']);
      data.push(query);
      links = neighbours.map((neighbour) => [query, neighbour]);
      drawScatterPlot(data, links, scale);
    });
  });
})();
</script>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment