Last active
June 8, 2018 02:13
-
-
Save mik30s/dea6a0762d335ebfc05021d13b9d9daf to your computer and use it in GitHub Desktop.
K Nearest Neighbors in Javascript
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!--
  KNN demo widget: control row (predict button, scale toggle, k value, point
  input, predicted-class icon) followed by the plot container.
  NOTE(review): the script below calls the global `Plotly`, but only d3 is
  loaded here; presumably the host page loads Plotly itself. Confirm.
-->
<div id="knn-react-app" class="embedded-app" style="text-align: center;">
  <button id="knn-1-predict-btn"> <i class=""></i> Predict</button>
  <input type="checkbox" id="knn-1-scale-btn" />
  <!-- `for` must match the checkbox id so that clicking the label toggles it. -->
  <label for="knn-1-scale-btn">Scale?</label>
  <input id="knn-1-kvalue" placeholder="# of neighbours eg. 3"/>
  <input id="knn-1-point-input" placeholder="weight,color,seeds eg. 371,3,1" />
  <i class="fas fa-equals"></i>
  <img width="25" height="25" id="knn-1-class-img" alt="Predicted class" />
</div>
<div id="knn-vis1" style="width: 100%; height: 520px;"></div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.17/d3.min.js"></script>
<script type="text/javascript" async> | |
/**
 * K-nearest-neighbours classifier over rows shaped `[...features, label]`.
 */
class KNearestNeighbours {
  /**
   * @param {Array<Array>} d - Training rows; the last element of each row is its label.
   * @param {number} k - Number of neighbours. Even values are bumped to the next odd number.
   * @param {Function} metric - Called as `metric(point, features, ...args)`; returns a numeric distance.
   * @param {Array} [args] - Extra arguments forwarded to `metric` (defaults to none).
   */
  constructor(d, k, metric, args = []) {
    this.k = k % 2 === 0 ? k + 1 : k;
    this.metric = metric;
    this.args = args;
    // Shallow copy so later pushes/slices on the caller's array do not affect the model.
    this.data = d.slice();
  }

  /**
   * Classifies each example by majority vote among its k nearest training rows.
   *
   * @param {Array<Array<number>>} examples - Feature vectors (no label).
   * @returns {Array<Array>} One `[label, neighbourRows, point]` triple per example.
   *   `label` is `null` when there are no neighbours to vote.
   */
  predict(examples) {
    const classes = [];
    for (const point of examples) {
      // Compute every distance exactly once and sort on the exact value.
      // Rounding the distance (as with Math.ceil) would turn all sub-unit
      // distances into ties and pick arbitrary neighbours on scaled data.
      const ranked = this.data
        .map((row) => ({
          row,
          distance: this.metric(point, row.slice(0, row.length - 1), ...this.args),
        }))
        .sort((a, b) => a.distance - b.distance);
      const kclosests = ranked.slice(0, this.k).map((entry) => entry.row);

      // A Map keeps arbitrary labels (e.g. "constructor") from colliding with
      // inherited object properties.
      const counts = new Map();
      let maxcnt = 0;
      let maxelem = null;
      for (const row of kclosests) {
        const label = row[row.length - 1];
        const count = (counts.get(label) || 0) + 1;
        counts.set(label, count);
        // Strict '<' means the first label to reach the top count wins a tie.
        if (maxcnt < count) {
          maxcnt = count;
          maxelem = label;
        }
      }
      classes.push([maxelem, kclosests, point]);
    }
    return classes;
  }

  /**
   * Re-classifies the training rows and reports how many were misclassified.
   *
   * @returns {number} Fraction of training rows whose predicted label differs
   *   from their stored label (an error rate: 0 is best). Returns 0 when there
   *   is no data.
   */
  score() {
    if (this.data.length === 0) {
      return 0;
    }
    const examples = this.data.map((row) => row.slice(0, row.length - 1));
    const predictions = this.predict(examples);
    // predictions[i] belongs to this.data[i]; comparing by index stays correct
    // even when two rows share the same feature values.
    let misses = 0;
    predictions.forEach((prediction, i) => {
      const row = this.data[i];
      if (row[row.length - 1] !== prediction[0]) {
        misses += 1;
      }
    });
    return misses / this.data.length;
  }
}
/**
 * Randomly partitions labelled rows into a training and a test set.
 *
 * Each row is `[...features, label]`. Rows are shuffled by index
 * (Fisher-Yates); the first `floor(data.length * trainPerc)` shuffled rows
 * form the training set and the remainder the test set.
 *
 * @param {Array<Array>} data - Labelled rows; not modified.
 * @param {number} trainPerc - Fraction of rows (0..1) to put in the training set.
 * @param {Function} [random=Math.random] - Source of numbers in [0, 1); injectable for repeatable splits.
 * @returns {{classTrainSet: Array, classTestSet: Array, pointTrainSet: Array<Array>, pointTestSet: Array<Array>}}
 *   Labels and feature vectors of both sets; `classXSet[i]` is the label of `pointXSet[i]`.
 */
function trainTestSplit(data, trainPerc, random = Math.random) {
  // Clamp so out-of-range fractions cannot produce negative or oversized slices.
  const trainCount = Math.min(
    data.length,
    Math.max(0, Math.floor(data.length * trainPerc))
  );

  // Shuffle the valid indices 0..length-1 so every row lands in exactly one set.
  const indices = data.map((row, i) => i);
  for (let i = indices.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    const swapped = indices[i];
    indices[i] = indices[j];
    indices[j] = swapped;
  }

  const trainSet = indices.slice(0, trainCount).map((index) => data[index]);
  const testSet = indices.slice(trainCount).map((index) => data[index]);

  const labelOf = (row) => row[row.length - 1];
  // Everything except the trailing label.
  const featuresOf = (row) => row.slice(0, row.length - 1);

  return {
    classTrainSet: trainSet.map(labelOf),
    classTestSet: testSet.map(labelOf),
    pointTrainSet: trainSet.map(featuresOf),
    pointTestSet: testSet.map(featuresOf),
  };
}
/**
 * Optionally rescales a list of numbers to the range [0, 1].
 *
 * @param {boolean} scale - Scaling is applied only when this is exactly `true`;
 *   otherwise `data` is returned as-is (same array reference).
 * @param {Array<number>} data - Values to scale; not modified.
 * @returns {Array<number>} `(v - min) / (max - min)` for each value. When all
 *   values are equal the range is treated as 1, so every value maps to 0.
 *   NaN entries stay NaN without affecting the other values.
 */
function minMaxScaler(scale, data) {
  if (scale !== true) {
    return data;
  }

  // A plain loop instead of Math.max(...data): spreading a very large array
  // into call arguments throws a RangeError.
  let min = Infinity;
  let max = -Infinity;
  for (const value of data) {
    if (value < min) {
      min = value;
    }
    if (value > max) {
      max = value;
    }
  }

  const range = max === min ? 1 : max - min;
  return data.map((value) => (value - min) / range);
}
/**
 * Minkowski distance of order `p` between two equal-length vectors:
 * `(sum_i |a[i] - b[i]|^p)^(1/p)`. p = 1 is Manhattan, p = 2 is Euclidean.
 *
 * @param {Array<number>} a - First vector; its length drives the iteration.
 * @param {Array<number>} b - Second vector.
 * @param {number} p - Order of the distance.
 * @returns {number} The distance; 0 for empty vectors.
 */
function minkowski(a, b, p) {
  // The 1/p root applies to the SUM of powered differences; rooting each term
  // individually would cancel the power and always yield the Manhattan distance.
  const total = a.reduce(
    (sum, value, i) => sum + Math.pow(Math.abs(value - b[i]), p),
    0
  );
  return Math.pow(total, 1 / p);
}
(function() { | |
/**
 * Renders the fruit data as a 3D scatter plot in the `knn-vis1` element,
 * one trace per class ("banana", "apple", "unknown").
 *
 * @param {Array<Array>} data - Rows shaped `[weight, color, seeds, type]`.
 * @param {Array<Object>} extraCharts - Shape objects copied into `layout.shapes`.
 * @param {boolean} scale - When `true`, each of the three feature columns is
 *   min-max scaled before plotting.
 */
function drawScatterPlot(data, extraCharts, scale) {
  // Scale every feature column over the WHOLE data set, then split by class.
  // Scaling each class separately would give every class its own [0, 1] range
  // and pin a lone "unknown" point at 0 on every axis.
  const columns = [0, 1, 2].map((axis) =>
    minMaxScaler(scale, data.map((row) => row[axis]))
  );
  const points = data.map((row, i) => ({
    x: columns[0][i],
    y: columns[1][i],
    z: columns[2][i],
    type: row[3],
  }));

  // Builds one scatter3d marker trace for the rows of a single class.
  const buildTrace = (type, name, marker) => {
    const members = points.filter((p) => p.type === type);
    return {
      x: members.map((p) => p.x),
      y: members.map((p) => p.y),
      z: members.map((p) => p.z),
      mode: 'markers',
      type: 'scatter3d',
      name: name,
      marker: Object.assign(
        { symbol: 'circle', opacity: 0.8, line: { width: 0.5 } },
        marker
      ),
    };
  };

  const traces = [
    buildTrace('banana', 'Banana', { size: 15, color: 'orange' }),
    buildTrace('apple', 'Apple', { size: 12, color: '#ee0000' }),
    buildTrace('unknown', 'Unknown', {
      size: 15,
      color: 'green',
      symbol: 'square',
      opacity: 1,
    }),
  ];

  const layout = {
    // Axes of 3D traces are configured under `scene`, not at the layout root.
    scene: {
      xaxis: { title: 'Weight (g)' },
      yaxis: { title: 'Color' },
      zaxis: { title: '# seeds' },
    },
    // NOTE(review): layout shapes appear to be a 2D-only feature in Plotly, so
    // these may not render in a 3D scene, and their coordinates are not scaled
    // here. Confirm; separate 'lines' scatter3d traces may be needed.
    shapes: extraCharts.slice(),
    autosize: false,
    width: 600,
    height: 500,
    legend: {
      orientation: 'h',
    },
    margin: {
      l: 0,
      r: 25,
      b: 0,
      t: 0,
      pad: 0,
    },
  };

  Plotly.newPlot('knn-vis1', traces, layout, { displayModeBar: true });
}
// Loads the fruit data set, draws the initial plot, and wires up the
// "Scale?" checkbox and the "Predict" button.
d3.csv("/assets/data/banana_data.csv", function(error, rows) {
  // Bail out if the request failed; there is nothing to plot or classify.
  if (error || !rows) {
    console.error("Could not load /assets/data/banana_data.csv", error);
    return;
  }

  // Rows become [weight, color, seeds, type]; `data` is reassigned when the
  // previous prediction point is dropped.
  let data = rows.map((row) => [
    parseFloat(row["weight"]),
    parseFloat(row["color"]),
    parseFloat(row["seeds"]),
    row["type"],
  ]);

  let scale = false;
  drawScatterPlot(data, [], scale);

  document.getElementById("knn-1-scale-btn").addEventListener("change", function() {
    scale = this.checked;
    drawScatterPlot(data, [], scale);
  });

  document.getElementById("knn-1-predict-btn").addEventListener("click", () => {
    const input = document.getElementById("knn-1-point-input").value
      .split(',')
      .map((part) => parseFloat(part));
    const kValue = parseInt(document.getElementById("knn-1-kvalue").value, 10);

    // parseInt/parseFloat signal bad input with NaN (never undefined/null), so
    // test for that explicitly. Exactly three coordinates are required because
    // every training row has three features.
    const isPointValid = input.length === 3 && input.every((v) => Number.isFinite(v));
    const isKValid = Number.isFinite(kValue) && kValue >= 1;
    if (!isPointValid || !isKValid) {
      alert("Incorrect inputs");
      return;
    }

    // Drop the point added by the previous prediction so it cannot vote.
    // Done only after validation so a bad input leaves the data untouched.
    if (data.length > 0 && data[data.length - 1][3] === "unknown") {
      data = data.slice(0, data.length - 1);
    }
    if (data.length === 0) {
      // No labelled rows to classify against.
      return;
    }

    const knn = new KNearestNeighbours(data, kValue, minkowski, [2]);
    const prediction = knn.predict([input])[0];
    const label = prediction[0];
    const neighbours = prediction[1];
    console.log("Label was", label);

    document.getElementById("knn-1-class-img").src =
      label === "banana" ? "/assets/img/banana.svg" : "/assets/img/apple.svg";

    // Plot the query point as its own "unknown" class.
    data.push(input.concat("unknown"));

    // One line from the query point to each of its nearest neighbours.
    const lines = neighbours.map((neighbour) => ({
      type: 'line',
      x0: input[0],
      y0: input[1],
      z0: input[2],
      x1: neighbour[0],
      y1: neighbour[1],
      z1: neighbour[2],
      line: {
        color: 'rgb(55, 128, 191)',
        width: 3,
      },
    }));

    drawScatterPlot(data, lines, scale);
  });
});
})() | |
</script> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment