Inspired by Professor Ng's lectures in the Coursera Machine Learning class, these animations visualize single-variable linear regression fitted by gradient descent. The graph on the left shows the data we are trying to fit, together with the hypothesis line as the parameters θ0 and θ1 converge. The plot on the right shows the value of the cost function at each (θ0, θ1) pair visited. The animation loops forever, each run starting from "random" values of θ0 and θ1.
Last active
May 26, 2018 03:43
-
-
Save jaredwinick/31cf37cb0cf5db911eedd54c78e992c1 to your computer and use it in GitHub Desktop.
Visualizing Linear Regression by Gradient Descent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<meta charset="utf-8">
<style>
  body {
    font: 10px sans-serif;
  }

  /* Axis domain path and tick marks: thin, crisp black strokes. */
  .axis path,
  .axis line {
    fill: none;
    stroke: #000;
    shape-rendering: crispEdges;
  }

  /* The hypothesis line drawn over the scatter plot. */
  .line {
    fill: none;
    stroke: steelblue;
    stroke-width: 1px;
  }
</style>
<body>
<!-- Both libraries are loaded over HTTPS: a plain http:// script is blocked as
     mixed content whenever this page itself is served from an HTTPS origin. -->
<script src="https://d3js.org/d3.v4.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3-legend/2.18.0/d3-legend.min.js"></script>
<script> | |
// Layout shared by both charts: each outer <svg> is 450x500 and the margins
// are reserved for the axes and their labels.
var margin = {top: 20, right: 20, bottom: 30, left: 50};
var width = 450 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Data-space bounds of the left (x/y) chart.
var minX = 0;
var minY = 0;
var maxX = 100;
var maxY = 100;

// Bounds for the slope parameter theta1 on the right chart.
var minSlope = -2;
var maxSlope = 2;

// Each run draws its starting intercept (theta0) and slope (theta1) from these.
var theta0Generator = d3.randomUniform(minY, maxY);
var theta1Generator = d3.randomUniform(minSlope, maxSlope);

// Left chart scales: data x/y -> pixels (the y range is flipped so larger values sit higher).
var x = d3.scaleLinear().domain([minX, maxX]).range([0, width]);
var y = d3.scaleLinear().domain([minY, maxY]).range([height, 0]);

// Right chart scales: theta0 on the horizontal axis, theta1 on the vertical axis.
var t0 = d3.scaleLinear().domain([minY, maxY]).range([0, width]);
var t1 = d3.scaleLinear().domain([minSlope, maxSlope]).range([height, 0]);

// One axis generator per scale.
var xAxis = d3.axisBottom(x);
var yAxis = d3.axisLeft(y);
var t0Axis = d3.axisBottom(t0);
var t1Axis = d3.axisLeft(t1);
// Left chart: the sample points plus the current hypothesis line.
// `svg` refers to the inner <g>, translated so (0, 0) is the plot area's top-left corner.
var svg = d3.select("body")
  .append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

// Horizontal axis sits along the bottom edge of the plot area.
svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + height + ")")
    .call(xAxis);

// Vertical axis sits along the left edge.
svg.append("g")
    .attr("class", "y axis")
    .call(yAxis);

// Axis titles. With text-anchor "middle" the translate target is the centre of the text.
svg.append("text")
    .attr("text-anchor", "middle")
    .attr("transform", "translate("+ (width/2) +","+(height+margin.bottom)+")") // centred below the x axis
    .text("x");

svg.append("text")
    .attr("text-anchor", "middle")
    .attr("transform", "translate("+ (-30) +","+(height/2)+")") // centred to the left of the y axis
    .text("y");
// Right chart: every (theta0, theta1) pair visited, coloured by its cost.
// `svg2` refers to the inner <g>, translated so (0, 0) is the plot area's top-left corner.
var svg2 = d3.select("body").append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

// theta0 axis along the bottom edge of the plot area.
svg2.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + height + ")")
    .call(t0Axis);

// theta1 axis along the left edge.
svg2.append("g")
    .attr("class", "y axis")
    .call(t1Axis);

// Axis titles. These are SVG <text> nodes, so they are filled with .text():
// selection.html() relies on innerHTML, which d3 documents as HTML-only.
svg2.append("text")
    .attr("text-anchor", "middle") // the transform below then positions the centre of the text
    .attr("transform", "translate("+ (width/2) +","+(height+margin.bottom)+")") // centred below the theta0 axis
    .text("Θ0");

svg2.append("text")
    .attr("text-anchor", "middle") // the transform below then positions the centre of the text
    .attr("transform", "translate("+ (-35) +","+(height/2)+")") // centred to the left of the theta1 axis
    .text("Θ1");
/**
 * Builds a noisy sample of the line y = a * x + b.
 *
 * Each point's x comes from d3.randomUniform(minX, maxX); its y is the exact
 * line value plus a draw from d3.randomNormal(0, 10), so the points scatter
 * around the line instead of lying on it.
 *
 * @param {number} numberOfPoints how many points to produce
 * @param {number} minX lower bound passed to the uniform x generator
 * @param {number} maxX upper bound passed to the uniform x generator
 * @param {number} a slope of the underlying line
 * @param {number} b intercept of the underlying line
 * @returns {Array<{x: number, y: number}>} the generated points
 */
function generateData(numberOfPoints, minX, maxX, a, b) {
  var drawX = d3.randomUniform(minX, maxX);
  var drawNoise = d3.randomNormal(0, 10);
  var points = [];
  for (var n = 0; n < numberOfPoints; ++n) {
    var px = drawX();
    var py = (a * px) + b + drawNoise();
    points.push({x: px, y: py});
  }
  return points;
}
// The sample the descent will fit: 200 noisy points around y = 1.06x + 26.
var data = generateData(200, minX, maxX, 1.06, 26);

// Draw one small dot per sample on the left chart.
var circles = svg.selectAll("circle")
  .data(data)
  .enter()
  .append("circle")
    .attr("cx", function(point) { return x(point.x); })
    .attr("cy", function(point) { return y(point.y); })
    .attr("r", 2);
// Hypothesis: the line with intercept theta0 and slope theta1, evaluated at x.
function h(x, theta0, theta1) {
  var slopeTerm = theta1 * x;
  return theta0 + slopeTerm;
}
/*
 * Performs one batch gradient-descent update of the hypothesis parameters.
 *
 * For every point the residual h(x) - y is computed once and folded into both
 * partial-derivative sums (the theta1 sum additionally weights it by x). Each
 * parameter then moves against its averaged sum, scaled by the learning rate.
 *
 * currentTheta0, currentTheta1 - parameters before the step
 * alpha                        - learning rate
 * data                         - array of {x, y} points
 *
 * returns {theta0, theta1}
 */
function gradiantDescentStep(currentTheta0, currentTheta1, alpha, data) {
  // With no points the averaged gradient would be 0 * (1 / 0) = NaN and would
  // poison both parameters, so the step is a no-op instead.
  if (data.length === 0) {
    return {theta0: currentTheta0, theta1: currentTheta1};
  }

  var theta0Sum = 0;
  var theta1Sum = 0;
  for (var i = 0; i < data.length; ++i) {
    var point = data[i];
    var error = h(point.x, currentTheta0, currentTheta1) - point.y;
    theta0Sum += error;
    theta1Sum += error * point.x;
  }

  var stepScale = alpha * (1.0 / data.length);
  return {
    theta0: currentTheta0 - (stepScale * theta0Sum),
    theta1: currentTheta1 - (stepScale * theta1Sum)
  };
}
/*
 * Cost of the hypothesis with the given parameters over `data`: the sum of
 * squared residuals (h(x) - y) divided by twice the number of points.
 */
function calculateCost(currentTheta0, currentTheta1, data) {
  var squaredErrorTotal = 0;
  data.forEach(function(point) {
    var residual = h(point.x, currentTheta0, currentTheta1) - point.y;
    squaredErrorTotal = squaredErrorTotal + Math.pow(residual, 2);
  });
  return squaredErrorTotal * (1.0 / (2 * data.length));
}
/*
 * Returns the two end points, at x = minX and x = maxX, of the hypothesis
 * line for the given parameters, as [{x, y}, {x, y}].
 */
function calculateLineData(theta0, theta1) {
  return [minX, maxX].map(function(xValue) {
    return {x: xValue, y: h(xValue, theta0, theta1)};
  });
}
/*
 * How many gradient-descent steps to run for the given iteration count.
 *
 * The first few iterations are artificially "slowed down" to a single step
 * each, so the line and the cost can be followed while the parameters are
 * still changing quickly. From iteration 20 on the count is twice the
 * iteration number, capped at 1000.
 */
function numberOfStepsForIteration(iteration) {
  return iteration < 20 ? 1 : Math.min(iteration * 2, 1000);
}
// Every (theta0, theta1, cost) sample recorded so far, across all runs.
var costData = [];

// Log-scaled colour ramp for the cost: blue for low cost through to red for high.
var costScale = d3.scaleLog()
  .domain([ 50, 100, 400, 1000, 4000 ])
  .range([
    d3.rgb("#2c7bb6"),
    d3.rgb('#00ccbc'),
    d3.rgb('#ffff8c'),
    d3.rgb('#f29e2e'),
    d3.rgb('#d7191c')
  ]);

// Legend with one cell per colour stop of the scale.
var legend = d3.legendColor()
  .cells(costScale.domain())
  .title("Cost")
  .scale(costScale);

// Place the legend near the right edge of the right chart and render it.
svg2.append("g")
    .attr("class", "legendLog")
    .attr("transform", "translate(" + (width-40) + ",10)")
    .call(legend);
/*
 * Runs one complete gradient-descent animation and, when it finishes, starts
 * the next one.
 *
 * A run starts from freshly drawn theta0/theta1, draws the matching hypothesis
 * line on the left chart and then, on every timer tick:
 *   1. moves the line to the current parameters,
 *   2. records and logs the current cost and plots it on the right chart,
 *   3. advances the descent by numberOfStepsForIteration(iteration) steps.
 *
 * A fixed budget of 30000 steps is used rather than a convergence check.
 */
function runGradientDescent() {
  var theta0 = theta0Generator();
  var theta1 = theta1Generator();

  var l = svg.append("line")
      .attr("class", "line");

  // Positions the hypothesis line for the current theta0/theta1.
  function drawLine() {
    var lineData = calculateLineData(theta0, theta1);
    l.attr("x1", x(lineData[0].x))
     .attr("y1", y(lineData[0].y))
     .attr("x2", x(lineData[1].x))
     .attr("y2", y(lineData[1].y));
  }
  drawLine();

  var iteration = 1;

  // d3.timer's second argument is a start delay: after 500ms the callback is
  // invoked repeatedly until the timer is stopped.
  var t = d3.timer(function() {
    drawLine();

    // add a new point to the cost data to render
    var cost = calculateCost(theta0, theta1, data);
    costData.push({x: theta0, y: theta1, cost: cost});
    console.log("cost:" + cost);

    svg2.selectAll("circle")
        .data(costData)
      .enter().append("circle")
        .attr("cx", function(d) { return t0(d.x); })
        .attr("cy", function(d) { return t1(d.y); })
        .attr("r", 4)
        .style("fill", function(d) { return costScale(d.cost); });

    // we are just running a fixed number of iterations
    // as opposed to checking for convergence
    if (iteration > 30000) {
      t.stop();
      // Drop this run's line so <line> elements do not pile up in the DOM,
      // then start all over again.
      l.remove();
      runGradientDescent();
      return;
    }

    // Fix the step count before looping: `iteration` grows inside the loop,
    // so re-evaluating the bound on every pass would inflate it mid-tick.
    var steps = numberOfStepsForIteration(iteration);
    for (var i = 0; i < steps; ++i) {
      var update = gradiantDescentStep(theta0, theta1, .0005, data);
      theta0 = update.theta0;
      theta1 = update.theta1;
      ++iteration;
    }
  }, 500);
}

runGradientDescent();
</script> | |
</body> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment