Forked from Brice Pierre de la Briere's version to add user-supplied parameters and a plot of the cost J(θ) over iterations.
Last active
February 19, 2018 12:59
-
-
Save feyderm/dd30a2ceea64826d7e1a7695c17dcdd6 to your computer and use it in GitHub Desktop.
Exploring Gradient Descent Parameters for Linear Regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
population | profit | |
---|---|---|
6.1101 | 17.592 | |
5.5277 | 9.1302 | |
8.5186 | 13.662 | |
7.0032 | 11.854 | |
5.8598 | 6.8233 | |
8.3829 | 11.886 | |
7.4764 | 4.3483 | |
8.5781 | 12 | |
6.4862 | 6.5987 | |
5.0546 | 3.8166 | |
5.7107 | 3.2522 | |
14.164 | 15.505 | |
5.734 | 3.1551 | |
8.4084 | 7.2258 | |
5.6407 | 0.71618 | |
5.3794 | 3.5129 | |
6.3654 | 5.3048 | |
5.1301 | 0.56077 | |
6.4296 | 3.6518 | |
7.0708 | 5.3893 | |
6.1891 | 3.1386 | |
20.27 | 21.767 | |
5.4901 | 4.263 | |
6.3261 | 5.1875 | |
5.5649 | 3.0825 | |
18.945 | 22.638 | |
12.828 | 13.501 | |
10.957 | 7.0467 | |
13.176 | 14.692 | |
22.203 | 24.147 | |
5.2524 | -1.22 | |
6.5894 | 5.9966 | |
9.2482 | 12.134 | |
5.8918 | 1.8495 | |
8.2111 | 6.5426 | |
7.9334 | 4.5623 | |
8.0959 | 4.1164 | |
5.6063 | 3.3928 | |
12.836 | 10.117 | |
6.3534 | 5.4974 | |
5.4069 | 0.55657 | |
6.8825 | 3.9115 | |
11.708 | 5.3854 | |
5.7737 | 2.4406 | |
7.8247 | 6.7318 | |
7.0931 | 1.0463 | |
5.0702 | 5.1337 | |
5.8014 | 1.844 | |
11.7 | 8.0043 | |
5.5416 | 1.0179 | |
7.5402 | 6.7504 | |
5.3077 | 1.8396 | |
7.4239 | 4.2885 | |
7.6031 | 4.9981 | |
6.3328 | 1.4233 | |
6.3589 | -1.4211 | |
6.2742 | 2.4756 | |
5.6397 | 4.6042 | |
9.3102 | 3.9624 | |
9.4536 | 5.4141 | |
8.8254 | 5.1694 | |
5.1793 | -0.74279 | |
21.279 | 17.929 | |
14.908 | 12.054 | |
18.959 | 17.054 | |
7.2182 | 4.8852 | |
8.2951 | 5.7442 | |
10.236 | 7.7754 | |
5.4994 | 1.0173 | |
20.341 | 20.992 | |
10.136 | 6.6799 | |
7.3345 | 4.0259 | |
6.0062 | 1.2784 | |
7.2259 | 3.3411 | |
5.0269 | -2.6807 | |
6.5479 | 0.29678 | |
7.5386 | 3.8845 | |
5.0365 | 5.7014 | |
10.274 | 6.7526 | |
5.1077 | 2.0576 | |
5.7292 | 0.47953 | |
5.1884 | 0.20421 | |
6.3557 | 0.67861 | |
9.7687 | 7.5435 | |
6.5159 | 5.3436 | |
8.5172 | 4.2415 | |
9.1802 | 6.7981 | |
6.002 | 0.92695 | |
5.5204 | 0.152 | |
5.0594 | 2.8214 | |
5.7077 | 1.8451 | |
7.6366 | 4.2959 | |
5.8707 | 7.2029 | |
5.3054 | 1.9869 | |
8.2934 | 0.14454 | |
13.394 | 9.0551 | |
5.4369 | 0.61705 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<meta charset="utf-8">
<style>
  /* Spacing between the parameter inputs. */
  input { margin-right: 15px; }
  body { font: 10px sans-serif; }

  /* Scatterplot data points. */
  .dot { fill: #d73027; }

  .axis path,
  .axis line {
    fill: none;
    stroke: #000;
    shape-rendering: crispEdges;
  }
  .axis text {
    text-anchor: middle;
    font-weight: bold;
  }

  /* Regression line h(x). */
  .line {
    fill: none;
    stroke: black;
    stroke-width: 1px;
    opacity: 0.8;
  }

  /* Hypothesis label hθ(x). */
  #hypothesis_fx {
    text-anchor: middle;
    font-size: 25px;
  }
</style>
<body>
<form action="">
  Number of Iterations: <input type="text" name="iterationNumber" value="100">
  Learning Rate: <input type="text" name="alpha" value="0.001">
  Initial Intercept: <input type="text" name="theta0" value="0">
  Initial Slope: <input type="text" name="theta1" value="0">
  <input type="button" value="Submit" onclick="updateParams(this.form)"><br>
</form>
<script src="https://d3js.org/d3.v3.min.js"></script>
<script> | |
// Inner plotting area: a 960x500 SVG minus the margins on each side.
var margin = { top: 20, right: 20, bottom: 50, left: 40 };
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Thetas in the hypothesis label are shown with three decimals.
var format = d3.format(".3f");

// Population runs left to right; profit runs bottom to top (SVG y grows
// downward). The domains are set once data.csv has loaded.
var x = d3.scale.linear().range([0, width]);
var y = d3.scale.linear().range([height, 0]);
var x_axis = d3.svg.axis().scale(x).orient("bottom");
var y_axis = d3.svg.axis().scale(y).orient("left");

// `svg` is the margin-translated <g> inside the root <svg>; every other
// element is drawn into it.
var svg = d3.select("body")
  .append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + [margin.left, margin.top].join(",") + ")");

// Label for the current hypothesis; its text is written by runGradientDescent.
var hyp = svg.append("text")
  .attr("id", "hypothesis_fx")
  .attr("x", 200)
  .attr("y", 50);

// Inset group holding the J(θ)-versus-iteration chart, placed near the right edge.
var cost_plot = svg.append("g")
  .attr("id", "cost_plot")
  .attr("transform", "translate(700, 230)");
// Load the training data, draw the axes and the scatterplot, then start an
// initial gradient-descent run.
d3.csv("data.csv", function(error, data) {
  // Surface a failed load directly instead of letting `data.forEach` below
  // fail with an unrelated TypeError on an undefined `data`.
  if (error) {
    throw error;
  }

  // CSV fields arrive as strings; convert them to numbers.
  data.forEach(function(d) {
    d.population = +d.population;
    d.profit = +d.profit;
  });

  x.domain([0, d3.max(data, function(d) { return d.population; })]).nice();
  y.domain(d3.extent(data, function(d) { return d.profit; })).nice();

  svg.append("g")
      .attr("class", "x axis")
      .attr("transform", "translate(0," + height + ")")
      .call(x_axis)
    .append("text")
      .attr("x", width / 2)
      .attr("y", 27)
      .text("Population of City in 10,000s");

  svg.append("g")
      .attr("class", "y axis")
      .call(y_axis)
    .append("text")
      .attr("transform", "rotate(-90)")
      .attr("x", -height / 2)
      .attr("y", -33)
      .attr("dy", ".71em")
      .text("Profit in $10,000s");

  svg.append("g")
      .attr("id", "scatterplot")
    .selectAll(".dot")
      .data(data)
    .enter()
    .append("circle")
      .attr("class", "dot")
      .attr("r", 3.5)
      .attr("cx", function(d) { return x(d.population); })
      .attr("cy", function(d) { return y(d.profit); });

  // Same values as the form's initial input values.
  runGradientDescent(100, 0.001, 0, 0);
});
/**
 * Click handler for the parameter form: read the four inputs and restart
 * gradient descent with them.
 *
 * Input is validated here so nonsensical values never reach
 * runGradientDescent: a NaN, fractional or negative iteration count cannot be
 * matched by an integer loop counter. If every field is numeric, the count is
 * truncated to a non-negative integer; otherwise the run is not started.
 *
 * @param {HTMLFormElement} form - form with iterationNumber, alpha, theta0 and
 *     theta1 text inputs.
 */
function updateParams(form) {
  var iterationNumber = +form.iterationNumber.value,
      alpha = +form.alpha.value,
      theta0 = +form.theta0.value,
      theta1 = +form.theta1.value;

  var allFinite = [iterationNumber, alpha, theta0, theta1].every(function(v) {
    return isFinite(v);
  });
  if (!allFinite) {
    console.warn("Ignoring gradient descent parameters: every field must be a number.");
    return;
  }

  runGradientDescent(Math.max(0, Math.floor(iterationNumber)), alpha, theta0, theta1);
}
/**
 * Remove what a previous gradient-descent run drew: the regression line
 * (#line), the J(θ) curve (#cost_line) and the cost-chart axes (.axis_cost).
 * The scatterplot and the main axes are left in place.
 */
function resetPlot() {
  ["#line", "#cost_line"].forEach(function(selector) {
    d3.select(selector).remove();
  });
  d3.selectAll(".axis_cost").remove();
}
/**
 * Animate batch gradient descent for h(x) = theta0 + theta1 * x over the
 * points bound to the scatterplot circles.
 *
 * The starting hypothesis is drawn immediately. Each d3.timer tick then
 * performs one simultaneous update of both thetas, records the cost J(θ),
 * and redraws the regression line and the hypothesis label. After exactly
 * `iterationNumber` updates, the recorded costs (iterations 0..N) are handed
 * to plotCost and the timer stops.
 *
 * Starting a new run invalidates any run that is still animating, so two
 * timers never write to the shared label or cost chart at the same time.
 *
 * @param {number} iterationNumber - number of gradient steps; truncated to a
 *     non-negative integer, and treated as 0 if it is not finite.
 * @param {number} alpha - learning rate.
 * @param {number} theta0 - initial intercept.
 * @param {number} theta1 - initial slope.
 */
function runGradientDescent(iterationNumber, alpha, theta0, theta1) {
  // Claim a new run id first; a timer from an earlier run sees the mismatch
  // on its next tick and stops itself.
  runGradientDescent.currentRun = (runGradientDescent.currentRun || 0) + 1;
  var runId = runGradientDescent.currentRun;

  resetPlot();

  var data = d3.selectAll("circle").data(),
      m = data.length;
  // Nothing to fit yet (e.g. Submit pressed before data.csv has loaded);
  // dividing by m = 0 would only yield NaN coordinates.
  if (m === 0) {
    return;
  }

  // The loop stops when an integer counter reaches `steps`; a fractional,
  // negative or NaN count could otherwise never be matched.
  var steps = isFinite(iterationNumber) ? Math.max(0, Math.floor(iterationNumber)) : 0;

  var xMin = x.domain()[0],
      xMax = x.domain()[1];

  var intercept = theta0,
      slope = theta1,
      iteration = 0;

  var line = svg.append("line")
      .attr("class", "line")
      .attr("id", "line");

  // J(θ) = 1/(2m) * Σ (h(x) - y)^2 for the current thetas.
  function computeCost() {
    var total = 0;
    data.forEach(function(d) {
      var error = slope * d.population + intercept - d.profit;
      total += error * error;
    });
    return total / (2 * m);
  }

  // One batch gradient step. Both partial derivatives are accumulated from
  // the current thetas before either is changed (simultaneous update).
  function gradientStep() {
    var grad0 = 0,
        grad1 = 0;
    data.forEach(function(d) {
      var error = slope * d.population + intercept - d.profit;
      grad0 += error;
      grad1 += error * d.population;
    });
    intercept -= alpha * grad0 / m;
    slope -= alpha * grad1 / m;
  }

  // Redraw the regression line and the label for the current thetas.
  function drawHypothesis() {
    line.attr("x1", x(xMin))
        .attr("y1", y(slope * xMin + intercept))
        .attr("x2", x(xMax))
        .attr("y2", y(slope * xMax + intercept));
    hyp.text("hθ(x) = " + format(intercept) + " + " + format(slope) + "x");
  }

  // Show the user-supplied starting hypothesis, not a hard-coded 0 + 0x.
  drawHypothesis();

  var max_cost = computeCost();
  var d_cost = [{ "iteration": 0, "cost": max_cost }];

  d3.timer(function() {
    if (runGradientDescent.currentRun !== runId) {
      return true; // superseded by a newer run
    }
    if (iteration < steps) {
      gradientStep();
      iteration += 1;
      d_cost.push({ "iteration": iteration, "cost": computeCost() });
      drawHypothesis();
    }
    if (iteration >= steps) {
      plotCost(d_cost, steps, max_cost);
      return true; // returning true stops a d3 v3 timer
    }
    return false;
  }, 200);
}
/**
 * Draw the J(θ)-versus-iteration curve and its axes into the inset chart
 * group `cost_plot` (150x150 px).
 *
 * @param {Array<{iteration: number, cost: number}>} d_cost - recorded cost per iteration.
 * @param {number} iterationNumber - last iteration; sets the x-axis extent.
 * @param {number} max_cost - caller-supplied upper bound for the y-axis. It is
 *     raised to the largest finite recorded cost, so that a diverging run
 *     (learning rate too large) stays inside the chart instead of being drawn
 *     over the main plot.
 */
function plotCost(d_cost, iterationNumber, max_cost) {
  var yMax = isFinite(max_cost) ? max_cost : 0;
  d_cost.forEach(function(d) {
    // Overflowed costs (Infinity/NaN) would make the scale unusable; skip them.
    if (isFinite(d.cost) && d.cost > yMax) {
      yMax = d.cost;
    }
  });

  var x_cost = d3.scale.linear()
      .domain([0, iterationNumber])
      .range([0, 150]);
  var x_axis_cost = d3.svg.axis()
      .scale(x_cost)
      .orient("bottom")
      .ticks(4);

  var y_cost = d3.scale.linear()
      .domain([0, yMax])
      .range([150, 0]);
  var y_axis_cost = d3.svg.axis()
      .scale(y_cost)
      .orient("left")
      .ticks(4);

  // Local, not an implicit global (which throws in strict mode).
  var cost_line = d3.svg.line()
      .x(function(d) { return x_cost(d.iteration); })
      .y(function(d) { return y_cost(d.cost); })
      .interpolate("basis");

  cost_plot.append("path")
      .datum(d_cost)
      .attr("id", "cost_line")
      .attr("d", cost_line)
      .attr("stroke", "black")
      .attr("fill", "none");

  cost_plot.append("g")
      .attr("class", "y axis axis_cost")
      .call(y_axis_cost)
    .append("text")
      .attr("transform", "rotate(-90)")
      .attr("x", -80)
      .attr("y", -35)
      .attr("dy", ".71em")
      .text("J(θ)");

  cost_plot.append("g")
      .attr("class", "x axis axis_cost")
      .attr("transform", "translate(0, 150)")
      .call(x_axis_cost)
    .append("text")
      .attr("x", 80)
      .attr("y", 30)
      .text("Num. of Iterations");
}
</script> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment