Skip to content

Instantly share code, notes, and snippets.

@jonahwilliams
Last active August 29, 2015 14:22
Show Gist options
  • Save jonahwilliams/e3eef13a85774df70e18 to your computer and use it in GitHub Desktop.
Save jonahwilliams/e3eef13a85774df70e18 to your computer and use it in GitHub Desktop.
Support Vector Machine II

An interactive Support Vector Machine (SVM) classifier. Click on the plot to add points of class 1 and class -1; the classifier activates once at least 6 points in total have been placed. Every time you add another point, the SMO algorithm recomputes the optimal separating hyperplane.

<!DOCTYPE html>
<head>
<meta charset="utf-8">
<!-- D3 is loaded over HTTPS so the script is not blocked as mixed content
     when this page itself is served from an HTTPS origin. -->
<script src="https://d3js.org/d3.v3.min.js"></script>
<!-- Provides SMO() and RBF(), used by the inline script below. -->
<script src="SMO.js"></script>
<style>
.axis {
font: 10px sans-serif;
}
path {
stroke: steelblue;
stroke-width: 2;
fill: none;
}
.axis path,
.axis line {
fill: none;
stroke: #000;
shape-rendering: crispEdges;
}
</style>
</head>
<body>
<!-- Class picker for the next point placed on the plot. The click handler
     reads these radio buttons directly, so the form needs no action URL. -->
<form>
<input type="radio" name="classes" value="-1" checked>-1<br>
<input type="radio" name="classes" value="1">1
</form>
<script>
// ---- Shared state -------------------------------------------------------
// Every placed point as [x, y, label], in data (not pixel) coordinates.
var data = [];
// Label given to the next point; refreshed from the radio buttons on click.
var color = -1;

// ---- Layout -------------------------------------------------------------
var margin = {top: 80, right: 180, bottom: 80, left: 180},
    width = 960 - margin.left - margin.right,
    height = 500 - margin.top - margin.bottom;

// Root <g>, shifted so that (0, 0) is the top-left corner of the plot area.
var svg = d3.select("body").append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
    .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

// ---- Scales, axes and line generator ------------------------------------
// The y range is inverted because SVG's y axis grows downwards.
var y = d3.scale.linear()
    .domain([0, 10])
    .range([height, 0]);
var x = d3.scale.linear()
    .domain([0, 10])
    .range([0, width]);
var xAxis = d3.svg.axis()
    .scale(x)
    .orient("bottom");
var yAxis = d3.svg.axis()
    .scale(y)
    .orient("left");
// Maps {x, y} objects in data coordinates to an SVG path string.
var line = d3.svg.line()
    .x(function(d) { return x(d.x); })
    .y(function(d) { return y(d.y); });

svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + height + ")")
    .call(xAxis);
svg.append("g")
    .attr("class", "y axis")
    .call(yAxis);

// ---- Point anchor -------------------------------------------------------
// Deliberately empty selection whose parent is the plot <g>. The click
// handler joins one datum at a time against it, so each new circle is
// appended to the plot. Styling is applied there; attribute calls on this
// empty selection would never run, so none are made here.
var points = svg.selectAll("circle")
    .data([]).enter()
    .append("circle");

// ---- Click target -------------------------------------------------------
// Invisible overlay that receives clicks. It is sized to the plot area (not
// "100%" of the SVG viewport) so that clicks in the right/bottom margins do
// not create points outside the axes.
var rect = svg.append("rect")
    .attr("width", width)
    .attr("height", height)
    .style("opacity", 0)
    .on("click", mousemove);

// ---- Classifier overlay paths -------------------------------------------
// Empty paths; their "d" attribute is filled in once the classifier runs.
var boundary = svg.append("path");
var neggut = svg.append("path");
var posgut = svg.append("path");
/**
 * Click handler for the plot overlay (despite the name, it is bound to
 * "click"). Reads the selected class from the "classes" radio group, records
 * the clicked position in data coordinates, draws a circle for it, and
 * retrains the classifier once more than five points exist.
 *
 * Must stay a regular function: `this` is the element d3 dispatched the
 * event on and is needed by d3.mouse().
 *
 * @param {*} d - datum supplied by d3 (unused)
 * @param {number} i - index supplied by d3 (unused)
 */
function mousemove(d, i) {
    // The first radio button is the "-1" class; anything else means "+1".
    var classInputs = document.getElementsByName("classes");
    color = classInputs[0].checked ? -1 : 1;

    // Convert the pixel position of the click back into data coordinates.
    var pixel = d3.mouse(this);
    var dat = [x.invert(pixel[0]), y.invert(pixel[1]), color];
    data.push(dat);

    var fillForClass = function(pt) {
        return pt[2] === -1 ? "red" : "blue";
    };

    // `points` is an empty selection, so joining a single datum always
    // produces exactly one entering circle.
    var entering = points.data([dat]).enter().append("circle");
    entering
        .attr("class", "dot")
        .attr("r", 3.5)
        .attr("cx", function(pt) { return x(pt[0]); })
        .attr("cy", function(pt) { return y(pt[1]); })
        .style("fill", fillForClass);

    if (data.length > 5) {
        Calculate();
    }
}
/**
 * Retrains the classifier on every point in `data` and redraws the decision
 * boundary together with both margin ("gutter") lines.
 *
 * The model returned by SMO() is treated as linear: a point (px, py) lies on
 * the boundary when w[0]*px + w[1]*py + b = 0 and on the margins when that
 * expression equals +1 or -1. If no line can be derived from the model (all
 * weights zero, or non-finite values) the previous drawing is left untouched.
 */
function Calculate(){
    var X = [],
        Y = [];
    for (var i = data.length - 1; i >= 0; i--) {
        X[i] = [data[i][0], data[i][1]];
        Y[i] = data[i][2];
    }

    // Keep the model local; assigning to an undeclared name would create an
    // implicit global (and throws in strict mode).
    var model = SMO(X, Y, 1, 0.000001, 30, -1);
    var w = model.w[0],
        b = model.b;

    var xExtent = x.domain(),
        yExtent = y.domain();

    // Returns the two end points of the line w[0]*px + w[1]*py + b = level
    // across the visible domain, or null when no such line can be drawn.
    // A straight line needs only its end points, not a dense sampling.
    function levelLine(level) {
        var ends;
        if (w[1] !== 0) {
            // Regular case: solve for py at both ends of the x domain.
            ends = xExtent.map(function(px) {
                return {'x': px, 'y': (level - b - w[0] * px) / w[1]};
            });
        } else if (w[0] !== 0) {
            // Vertical line: px is constant, span the whole y domain.
            var fixedX = (level - b) / w[0];
            ends = yExtent.map(function(py) {
                return {'x': fixedX, 'y': py};
            });
        } else {
            return null;
        }
        var allFinite = ends.every(function(p) {
            return isFinite(p.x) && isFinite(p.y);
        });
        return allFinite ? ends : null;
    }

    var decision = levelLine(0),
        pg = levelLine(1),    // margin on the +1 side
        ng = levelLine(-1);   // margin on the -1 side
    if (!decision || !pg || !ng) {
        // Writing NaN/Infinity coordinates would yield an invalid "d" attribute.
        return;
    }

    // Each path gets its own id; element ids must be unique in a document.
    boundary.datum(decision)
        .transition()
        .attr("class", "line")
        .attr("id", "decision")
        .attr("d", line);
    posgut.datum(pg)
        .transition()
        .attr("class", "line")
        .attr("id", "positive-gutter")
        .attr("d", line)
        .style("stroke-dasharray", "10,10");
    neggut.datum(ng)
        .transition()
        .attr("class", "line")
        .attr("id", "negative-gutter")
        .attr("d", line)
        .style("stroke-dasharray", "10,10");
}
</script>
</body>
/**
 * Trains a soft-margin SVM with the simplified Sequential Minimal
 * Optimization (SMO) scheme: for every multiplier that violates the KKT
 * conditions a second multiplier is picked at random and the pair is
 * optimised jointly.
 *
 * @param {number[][]} X - training points, one array of features per point
 * @param {number[]} y - class label per point; expected to be +1 or -1
 * @param {number} C - upper bound applied to every Lagrange multiplier
 * @param {number} tolerance - slack allowed when testing the KKT conditions
 * @param {number} max_passes - number of consecutive sweeps without any
 *     multiplier change after which training stops
 * @param {number} gamma - forwarded unchanged to RBF()
 * @returns {{w: number[][], b: number}} `w[0]` is the sum over all points of
 *     a[i] * y[i] * X[i]; `b` is the threshold. With fewer than two points
 *     nothing can be optimised, so all weights are zero and `b` is 0.
 */
function SMO(X, y, C, tolerance, max_passes, gamma){
    var n = X.length,            // number of training points
        a = [],                  // Lagrange multipliers
        E = [],                  // errors f(x_i) - y_i
        b = 0.0,                 // threshold
        passes = 0,              // consecutive sweeps without a change
        num_changed_alphas = 0;
    var i, j, k, L, H;

    // Initialise the multipliers and the error cache.
    for (i = n - 1; i >= 0; i--) {
        a[i] = 0.0;
        E[i] = 0.0;
    }

    // The `n > 1` guard matters: with a single point the random search for a
    // partner index j != i below could never terminate.
    while (n > 1 && passes < max_passes) {
        num_changed_alphas = 0;
        for (i = n - 1; i >= 0; i--) {
            // E_i = f(x_i) - y_i
            E[i] = b - y[i];
            for (j = n - 1; j >= 0; j--) {
                E[i] += (y[j] * a[j] * RBF(X[i], X[j], gamma));
            }
            // Only work on multipliers that violate the KKT conditions.
            if ((y[i] * E[i] < -tolerance && a[i] < C) || (y[i] * E[i] > tolerance && a[i] > 0)){
                // Select j != i at random.
                do {
                    j = Math.floor(Math.random() * n);
                }
                while (j === i);
                // E_j = f(x_j) - y_j
                E[j] = b - y[j];
                for (k = n - 1; k >= 0; k--) {
                    E[j] += (y[k] * a[k] * RBF(X[j], X[k], gamma));
                }
                var alpha_old_i = a[i],
                    alpha_old_j = a[j];
                // Bounds L and H that keep both multipliers inside [0, C].
                if (y[i] !== y[j]){
                    L = Math.max(0, a[j] - a[i]);
                    H = Math.min(C, C + a[j] - a[i]);
                }
                else {
                    L = Math.max(0, a[i] + a[j] - C);
                    H = Math.min(C, a[j] + a[i]);
                }
                if (L === H){
                    continue;
                }
                // Second derivative of the objective along the constraint
                // line; a usable step requires it to be strictly negative.
                var nen = 2 * RBF(X[i], X[j], gamma) -
                    RBF(X[i], X[i], gamma) -
                    RBF(X[j], X[j], gamma);
                if (nen >= 0){
                    continue;
                }
                // Unconstrained optimum for a_j, then clip it into [L, H].
                a[j] = a[j] - ((y[j] * (E[i] - E[j])) / nen);
                if (a[j] > H){
                    a[j] = H;
                }
                else if (a[j] < L){
                    a[j] = L;
                }
                // Ignore negligible moves (note: 10e-5 is 1e-4).
                if (Math.abs(a[j] - alpha_old_j) < 10e-5){
                    continue;
                }
                // Move a_i by the same amount in the opposite direction.
                a[i] = a[i] + (y[i] * y[j] * (alpha_old_j - a[j]));
                // Candidate thresholds derived from point i and point j.
                var b1 = b - E[i] - y[i] * (a[i] - alpha_old_i) *
                    RBF(X[i], X[i], gamma) - y[j] *
                    (a[j] - alpha_old_j) * RBF(X[i], X[j], gamma);
                var b2 = b - E[j] - y[i] * (a[i] - alpha_old_i) *
                    RBF(X[i], X[j], gamma) - y[j] *
                    (a[j] - alpha_old_j) * RBF(X[j], X[j], gamma);
                // Prefer the threshold of a multiplier strictly inside (0, C).
                if ((a[i] > 0) && (a[i] < C)){
                    b = b1;
                }
                else if ((a[j] > 0) && (a[j] < C)){
                    b = b2;
                }
                else {
                    b = (b1 + b2) / 2;
                }
                num_changed_alphas += 1;
            }
        }
        // Count quiet sweeps; any change restarts the count.
        if (num_changed_alphas === 0){
            passes += 1;
        }
        else {
            passes = 0;
        }
    }

    // Accumulate the weights; an empty training set yields an empty vector
    // instead of failing on X[0].
    var dims = n > 0 ? X[0].length : 0;
    var w = [[]];
    for (i = dims - 1; i >= 0; i--) {
        w[0][i] = 0;
    }
    for (i = n - 1; i >= 0; i--) {
        for (j = dims - 1; j >= 0; j--) {
            w[0][j] += a[i] * X[i][j] * y[i];
        }
    }
    return {'w' : w, 'b' : b};
}
/**
 * Kernel used by SMO(). Despite its name this is the plain dot product of
 * the two vectors; `gamma` is accepted but not used.
 *
 * @param {number[]} a - first vector; its length drives the iteration
 * @param {number[]} b - second vector
 * @param {number} gamma - ignored
 * @returns {number} sum of a[k] * b[k] over every index k of `a`
 */
function RBF(a, b, gamma){
    var total = 0.0;
    // Accumulate from the last component down to the first.
    var idx = a.length;
    while (idx-- > 0) {
        total += a[idx] * b[idx];
    }
    return total;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment