An interactive Support Vector Machine (SVM) classifier. Place at least 6 points, including points of both class 1 and class -1, to activate the classifier. Each time you add another point, the SMO algorithm recomputes the optimal separating hyperplane.
Last active
August 29, 2015 14:22
-
-
Save jonahwilliams/e3eef13a85774df70e18 to your computer and use it in GitHub Desktop.
Support Vector Machine II
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Support Vector Machine II</title>
<!-- d3 v3 (the script below uses d3.scale.linear, d3.svg.axis, d3.svg.line, d3.mouse).
     Loaded over HTTPS so browsers do not block it as mixed content when the page
     itself is served over HTTPS. -->
<script src="https://d3js.org/d3.v3.min.js"></script>
<!-- Defines SMO(), which the page script calls to train the classifier. -->
<script src="SMO.js"></script>
<style>
.axis {
  font: 10px sans-serif;
}
/* Default look for paths: the decision boundary and margin lines drawn by the script. */
path {
  stroke: steelblue;
  stroke-width: 2;
  fill: none;
}
.axis path,
.axis line {
  fill: none;
  stroke: #000;
  shape-rendering: crispEdges;
}
</style>
</head>
<body>
<!-- Class of the next clicked point. The script looks these radios up by name and treats
     the first one as class -1 and any other selection as class 1. The form is never
     submitted; onsubmit guards against an accidental navigation. -->
<form onsubmit="return false;">
  <label><input type="radio" name="classes" value="-1" checked>-1</label><br>
  <label><input type="radio" name="classes" value="1">1</label>
</form>
<script>
// Every clicked point, stored as [xValue, yValue, label] in data units (label is -1 or 1).
var data = [];
// Label of the most recently placed point; refreshed from the radios on each click.
var color = -1;
var margin = {top: 80, right: 180, bottom: 80, left: 180},
    width = 960 - margin.left - margin.right,
    height = 500 - margin.top - margin.bottom;
// Plot group, shifted by the margins so that (0, 0) is the top-left corner of the plot area.
var svg = d3.select("body").append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
// Data space is [0, 10] x [0, 10]; the y range is inverted so larger values are drawn higher.
var y = d3.scale.linear()
    .domain([0, 10])
    .range([height, 0]);
var x = d3.scale.linear()
    .domain([0, 10])
    .range([0, width]);
var xAxis = d3.svg.axis()
    .scale(x)
    .orient("bottom");
var yAxis = d3.svg.axis()
    .scale(y)
    .orient("left");
// Path generator for the boundary/margin lines; each vertex is an {x, y} object in data units.
var line = d3.svg.line()
    .x(function(d) { return x(d.x); })
    .y(function(d) { return y(d.y); });
svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + height + ")")
    .call(xAxis);
svg.append("g")
    .attr("class", "y axis")
    .call(yAxis);
// Empty placeholder selection whose parent is the plot group. The click handler binds each
// new point to it and appends the resulting <circle> (entering node) to that group.
var points = svg.selectAll("circle")
    .data([]).enter()
    .append("circle");
// Transparent click target covering exactly the plot area. It is sized to the plot rather
// than "100%" of the viewport, so clicks in the margins (outside the axes) do not add points.
var rect = svg.append("rect")
    .attr("width", width)
    .attr("height", height)
    .style("opacity", 0)
    .on("click", mousemove);
// Paths for the decision boundary and the two margin ("gutter") lines, drawn by Calculate().
var boundary = svg.append("path");
var neggut = svg.append("path");
var posgut = svg.append("path");
/**
 * Click handler for the transparent plot overlay (despite its name it is bound to "click").
 * Adds a point of the currently selected class at the cursor position, draws it, and
 * retrains the SVM once there are more than 5 points covering both classes.
 * `this` is the overlay rect; the d and i arguments are unused.
 */
function mousemove(d, i) {
  // radios[0] is the "-1" option; anything else means class 1.
  var radios = document.getElementsByName("classes");
  color = radios[0].checked ? -1 : 1;
  // Cursor position relative to the plot group, converted back to data units.
  var position = d3.mouse(this);
  var dat = [x.invert(position[0]), y.invert(position[1]), color];
  data.push(dat);
  points.data([dat])
    .enter().append("circle")
      .attr("class", "dot")
      .attr("r", 3.5)
      .attr("cx", function(p) { return x(p[0]); })
      .attr("cy", function(p) { return y(p[1]); })
      .style("fill", function(p) { return p[2] == -1 ? "red" : "blue"; });
  // Training needs both classes: with a single class SMO cannot move any multiplier,
  // w stays [0, 0] and the boundary computation degenerates into 0 / 0.
  var hasNegative = false,
      hasPositive = false;
  for (var k = 0; k < data.length; k++) {
    if (data[k][2] == -1) {
      hasNegative = true;
    } else {
      hasPositive = true;
    }
  }
  if (data.length > 5 && hasNegative && hasPositive) {
    Calculate();
  }
}
/**
 * Retrains the SVM on every clicked point and redraws the decision boundary
 * (w . p + b = 0) together with its two margin lines (w . p + b = +1 and -1).
 * Reads the page globals data, x, y, line, boundary, posgut and neggut.
 */
function Calculate(){
  // Split the clicked points into inputs X = [[px, py], ...] and labels Y (-1 / 1).
  var X = [],
      Y = [];
  for (var i = data.length - 1; i >= 0; i--) {
    X[i] = [data[i][0], data[i][1]];
    Y[i] = data[i][2];
  }
  // SMO(X, y, C, tolerance, max_passes, gamma). Declared with var so the model
  // does not leak into the global scope.
  var V = SMO(X, Y, 1, 0.000001, 30, -1);
  var w = V.w[0],
      b = V.b;

  // Two end points (data units) of the line {p : w . p + b = level}.
  // The line is parameterised by the coordinate whose weight is smaller in magnitude, so
  // a vertical or very steep boundary (w[1] zero or tiny) never divides by zero; the end
  // points then lie on the edges of the y-domain instead of the x-domain.
  // Returns null when w is zero or non-finite, i.e. there is no line to draw.
  function levelLine(level) {
    var scale = Math.max(Math.abs(w[0]), Math.abs(w[1]));
    if (!(scale > 0) || !isFinite(scale) || !isFinite(b)) {
      return null;
    }
    if (Math.abs(w[1]) >= Math.abs(w[0])) {
      return x.domain().map(function(px) {
        return {'x': px, 'y': (level - b - w[0] * px) / w[1]};
      });
    }
    return y.domain().map(function(py) {
      return {'x': (level - b - w[1] * py) / w[0], 'y': py};
    });
  }

  var decision = levelLine(0),
      pg = levelLine(1),   // positive margin: w . p + b = +1
      ng = levelLine(-1);  // negative margin: w . p + b = -1
  // Degenerate model (e.g. SMO could not move any multiplier): keep what is currently drawn.
  if (decision === null) {
    return;
  }
  boundary.datum(decision)
    .transition()
      .attr("class", "line")
      .attr("id", "decision")
      .attr("d", line);
  // Element ids must be unique in a document, so each margin path gets its own id.
  posgut.datum(pg)
    .transition()
      .attr("class", "line")
      .attr("id", "positive-margin")
      .attr("d", line)
      .style("stroke-dasharray", "10,10");
  neggut.datum(ng)
    .transition()
      .attr("class", "line")
      .attr("id", "negative-margin")
      .attr("d", line)
      .style("stroke-dasharray", "10,10");
}
</script> | |
<!-- End of page body. The plot itself is an <svg> that the script above appends to <body>. -->
</body> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Trains a soft-margin SVM with the simplified SMO algorithm (Sequential Minimal
 * Optimization, choosing the second multiplier of each pair at random) and returns the
 * primal weights and threshold of the decision function f(x) = w . x + b.
 * Equation numbers in the comments follow Stanford CS229's "The Simplified SMO Algorithm".
 *
 * w is rebuilt as w = sum_i a_i y_i x_i, which matches f only for the dot-product kernel
 * computed by RBF().
 *
 * @param {number[][]} X - training points, one row per point.
 * @param {number[]} y - labels (+1 / -1), aligned with X.
 * @param {number} C - box constraint: every multiplier stays in [0, C].
 * @param {number} tolerance - KKT violation tolerance.
 * @param {number} max_passes - stop after this many consecutive passes with no change.
 * @param {number} gamma - forwarded unchanged to RBF().
 * @returns {{w: number[][], b: number}} w wrapped in a one-row array, and the threshold b.
 */
function SMO(X, y, C, tolerance, max_passes, gamma){
  var n = X.length;
  // No training points: nothing to optimise and no dimension to size w with.
  if (n === 0) {
    return {'w': [[]], 'b': 0.0};
  }
  var dim = X[0].length,
      a = [],       // Lagrange multipliers
      E = [],       // errors E_k = f(x_k) - y_k
      K = [],       // Gram matrix K[p][q] = RBF(X[p], X[q], gamma)
      b = 0.0,      // threshold
      passes = 0,   // consecutive passes without any multiplier change
      num_changed_alphas,
      i, j, k;
  // Kernel values never change during training, so evaluate each pair once up front
  // instead of O(n) times for every error computation.
  for (i = 0; i < n; i++) {
    a[i] = 0.0;
    E[i] = 0.0;
    K[i] = [];
    for (j = 0; j < n; j++) {
      K[i][j] = RBF(X[i], X[j], gamma);
    }
  }

  // E_p = f(x_p) - y_p, with f(x) = sum_m a_m y_m K(x_m, x) + b for the current a and b.
  function errorAt(p) {
    var e = b - y[p];
    for (var m = n - 1; m >= 0; m--) {
      e += y[m] * a[m] * K[p][m];
    }
    return e;
  }

  // Each step optimises a *pair* of multipliers, so it needs two distinct points;
  // with a single point the "pick j != i" loop below could never terminate.
  while (n > 1 && passes < max_passes) {
    num_changed_alphas = 0;
    for (i = n - 1; i >= 0; i--) {
      E[i] = errorAt(i);
      var violatesKKT = (y[i] * E[i] < -tolerance && a[i] < C) ||
                        (y[i] * E[i] > tolerance && a[i] > 0);
      if (!violatesKKT) {
        continue;
      }
      // Select j != i at random.
      do {
        j = Math.floor(Math.random() * n);
      } while (j === i);
      E[j] = errorAt(j);
      var alpha_old_i = a[i],
          alpha_old_j = a[j];
      // Bounds L <= a_j <= H that keep both multipliers inside [0, C] (eqs. 10, 11).
      var L, H;
      if (y[i] != y[j]) {
        L = Math.max(0, a[j] - a[i]);
        H = Math.min(C, C + a[j] - a[i]);
      } else {
        L = Math.max(0, a[i] + a[j] - C);
        H = Math.min(C, a[j] + a[i]);
      }
      if (L == H) {
        continue;
      }
      // eta = 2 K_ij - K_ii - K_jj (eq. 14); a non-negative value gives no usable step.
      var eta = 2 * K[i][j] - K[i][i] - K[j][j];
      if (eta >= 0) {
        continue;
      }
      // Unconstrained optimum for a_j, clipped to [L, H] (eqs. 12, 15). Kept in a local
      // so that a rejected (negligible) step leaves a_j untouched; committing it without
      // the matching a_i update would break the constraint sum_k y_k a_k = 0.
      var alpha_new_j = alpha_old_j - (y[j] * (E[i] - E[j])) / eta;
      if (alpha_new_j > H) {
        alpha_new_j = H;
      } else if (alpha_new_j < L) {
        alpha_new_j = L;
      }
      // NOTE(review): 10e-5 is 1e-4; the CS229 notes use 1e-5 — confirm which was intended.
      if (Math.abs(alpha_new_j - alpha_old_j) < 10e-5) {
        continue;
      }
      a[j] = alpha_new_j;
      // Move a_i by the same amount in the compensating direction (eq. 16).
      a[i] = a[i] + (y[i] * y[j] * (alpha_old_j - a[j]));
      // Threshold candidates (eqs. 17, 18) and the choice between them (eq. 19).
      var b1 = b - E[i] - y[i] * (a[i] - alpha_old_i) * K[i][i] -
               y[j] * (a[j] - alpha_old_j) * K[i][j];
      var b2 = b - E[j] - y[i] * (a[i] - alpha_old_i) * K[i][j] -
               y[j] * (a[j] - alpha_old_j) * K[j][j];
      if (a[i] > 0 && a[i] < C) {
        b = b1;
      } else if (a[j] > 0 && a[j] < C) {
        b = b2;
      } else {
        b = (b1 + b2) / 2;
      }
      num_changed_alphas += 1;
    }
    passes = (num_changed_alphas === 0) ? passes + 1 : 0;
  }

  // Primal weights w = sum_i a_i y_i x_i.
  var w = [[]];
  for (k = dim - 1; k >= 0; k--) {
    w[0][k] = 0;
  }
  for (i = n - 1; i >= 0; i--) {
    for (k = dim - 1; k >= 0; k--) {
      w[0][k] += a[i] * X[i][k] * y[i];
    }
  }
  return {'w': w, 'b': b};
}
/**
 * Kernel used by SMO(). Despite its name this is the linear kernel: it returns the
 * dot product of a and b. The gamma argument is accepted but ignored.
 *
 * @param {number[]} a - first vector.
 * @param {number[]} b - second vector, indexed with a's indices.
 * @param {number} gamma - unused.
 * @returns {number} the sum of a[i] * b[i], accumulated from the last index down to 0.
 */
function RBF(a, b, gamma){
  var total = 0.0;
  var idx = a.length;
  while (idx--) {
    total += a[idx] * b[idx];
  }
  return total;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment