Skip to content

Instantly share code, notes, and snippets.

@feyderm
Last active February 19, 2018 12:59
Show Gist options
  • Save feyderm/dd30a2ceea64826d7e1a7695c17dcdd6 to your computer and use it in GitHub Desktop.
Save feyderm/dd30a2ceea64826d7e1a7695c17dcdd6 to your computer and use it in GitHub Desktop.
Exploring Gradient Descent Parameters for Linear Regression
population profit
6.1101 17.592
5.5277 9.1302
8.5186 13.662
7.0032 11.854
5.8598 6.8233
8.3829 11.886
7.4764 4.3483
8.5781 12
6.4862 6.5987
5.0546 3.8166
5.7107 3.2522
14.164 15.505
5.734 3.1551
8.4084 7.2258
5.6407 0.71618
5.3794 3.5129
6.3654 5.3048
5.1301 0.56077
6.4296 3.6518
7.0708 5.3893
6.1891 3.1386
20.27 21.767
5.4901 4.263
6.3261 5.1875
5.5649 3.0825
18.945 22.638
12.828 13.501
10.957 7.0467
13.176 14.692
22.203 24.147
5.2524 -1.22
6.5894 5.9966
9.2482 12.134
5.8918 1.8495
8.2111 6.5426
7.9334 4.5623
8.0959 4.1164
5.6063 3.3928
12.836 10.117
6.3534 5.4974
5.4069 0.55657
6.8825 3.9115
11.708 5.3854
5.7737 2.4406
7.8247 6.7318
7.0931 1.0463
5.0702 5.1337
5.8014 1.844
11.7 8.0043
5.5416 1.0179
7.5402 6.7504
5.3077 1.8396
7.4239 4.2885
7.6031 4.9981
6.3328 1.4233
6.3589 -1.4211
6.2742 2.4756
5.6397 4.6042
9.3102 3.9624
9.4536 5.4141
8.8254 5.1694
5.1793 -0.74279
21.279 17.929
14.908 12.054
18.959 17.054
7.2182 4.8852
8.2951 5.7442
10.236 7.7754
5.4994 1.0173
20.341 20.992
10.136 6.6799
7.3345 4.0259
6.0062 1.2784
7.2259 3.3411
5.0269 -2.6807
6.5479 0.29678
7.5386 3.8845
5.0365 5.7014
10.274 6.7526
5.1077 2.0576
5.7292 0.47953
5.1884 0.20421
6.3557 0.67861
9.7687 7.5435
6.5159 5.3436
8.5172 4.2415
9.1802 6.7981
6.002 0.92695
5.5204 0.152
5.0594 2.8214
5.7077 1.8451
7.6366 4.2959
5.8707 7.2029
5.3054 1.9869
8.2934 0.14454
13.394 9.0551
5.4369 0.61705
<!DOCTYPE html>
<meta charset="utf-8">
<style>
input {
margin-right: 15px;
}
body {
font: 10px sans-serif;
}
.dot {
fill: #d73027;
}
.axis path,
.axis line {
fill: none;
stroke: #000;
shape-rendering: crispEdges;
}
.axis text {
text-anchor: middle;
font-weight: bold;
}
.line {
fill: none;
stroke: black;
stroke-width: 1px;
opacity: 0.8;
}
#hypothesis_fx {
text-anchor: middle;
font-size: 25px;
}
</style>
<body>
<form action="">
Number of Iterations: <input type="text" name="iterationNumber" value="100">
Learning Rate: <input type="text" name="alpha" value=0.001>
Initial Intercept: <input type="text" name="theta0" value=0>
Initial Slope: <input type="text" name="theta1" value=0>
<input type="button" value="Submit" onClick=updateParams(this.form)><br>
</form>
<script src="https://d3js.org/d3.v3.min.js"></script>
<script>
// Outer SVG is 960x500; the margins leave room for the axes and their labels.
var margin = { top: 20, right: 20, bottom: 50, left: 40 };
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Coefficients in the hypothesis label are shown with three decimals.
var format = d3.format(".3f");

// Linear scales for the scatterplot; their domains are set once the data loads.
var x = d3.scale.linear().range([0, width]);
var y = d3.scale.linear().range([height, 0]);

var x_axis = d3.svg.axis().scale(x).orient("bottom");
var y_axis = d3.svg.axis().scale(y).orient("left");

// Root drawing group, shifted by the left/top margins so (0, 0) is the
// top-left corner of the plotting area.
var svg = d3.select("body")
  .append("svg")
  .attr("width", width + margin.left + margin.right)
  .attr("height", height + margin.top + margin.bottom)
  .append("g")
  .attr("transform", "translate(" + [margin.left, margin.top].join(",") + ")");

// Text element that displays the current hypothesis function.
var hyp = svg.append("text")
  .attr("id", "hypothesis_fx")
  .attr("x", 200)
  .attr("y", 50);

// Container group for the cost-vs-iteration inset plot.
var cost_plot = svg.append("g")
  .attr("id", "cost_plot")
  .attr("transform", "translate(700, 230)");
// Load the training data, draw the axes and the scatterplot, then start a
// gradient-descent run with the same defaults the form is pre-filled with.
d3.csv("data.csv", function(error, data) {
  // When the request fails there are no rows to iterate, so bail out with a
  // readable message instead of throwing on `data.forEach`.
  if (error || !data) {
    console.error("Could not load data.csv:", error);
    return;
  }

  // CSV fields arrive as strings; coerce the two columns to numbers.
  data.forEach(function(d) {
    d.population = +d.population;
    d.profit = +d.profit;
  });

  // x starts at 0 so the intercept is visible; y spans the observed profits.
  x.domain([0, d3.max(data, function(d) { return d.population; })]).nice();
  y.domain(d3.extent(data, function(d) { return d.profit; })).nice();

  svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + height + ")")
    .call(x_axis)
    .append("text")
    .attr("x", width / 2)
    .attr("y", 27)
    .text("Population of City in 10,000s");

  svg.append("g")
    .attr("class", "y axis")
    .call(y_axis)
    .append("text")
    .attr("transform", "rotate(-90)")
    .attr("x", -height / 2)
    .attr("y", -33)
    .attr("dy", ".71em")
    .text("Profit in $10,000s");

  // One circle per row; runGradientDescent reads the rows back from these
  // circles' bound data.
  svg.append("g")
    .attr("id", "scatterplot")
    .selectAll(".dot")
    .data(data)
    .enter()
    .append("circle")
    .attr("class", "dot")
    .attr("r", 3.5)
    .attr("cx", function(d) { return x(d.population); })
    .attr("cy", function(d) { return y(d.profit); });

  runGradientDescent(100, 0.001, 0, 0);
});
/**
 * Reads the gradient-descent settings from the controls form and starts a new
 * run with them.
 *
 * The run is only started when all four fields hold finite numbers and the
 * iteration count is not negative; otherwise a warning is logged and nothing
 * else happens.
 *
 * @param {HTMLFormElement} form - form exposing the `iterationNumber`,
 *   `alpha`, `theta0` and `theta1` inputs.
 */
function updateParams(form) {
  // Unary plus turns a blank field into 0, which would hide a missing value,
  // so blanks are mapped to NaN and rejected with the other invalid input.
  function readNumber(field) {
    var raw = String(field.value).trim();
    return raw === "" ? NaN : +raw;
  }

  var iterationNumber = readNumber(form.iterationNumber),
      alpha = readNumber(form.alpha),
      theta0 = readNumber(form.theta0),
      theta1 = readNumber(form.theta1);

  var allFinite = [iterationNumber, alpha, theta0, theta1].every(function(value) {
    return isFinite(value);
  });

  // A NaN or infinite parameter would propagate into every coefficient and
  // into the loop bound, so refuse to start the run.
  if (!allFinite || iterationNumber < 0) {
    console.warn("Gradient descent parameters must be finite numbers and the iteration count must not be negative.");
    return;
  }

  runGradientDescent(iterationNumber, alpha, theta0, theta1);
}
/**
 * Clears everything a previous run drew: the fitted line, the cost curve and
 * both axes of the cost plot.
 */
function resetPlot() {
  // The regression line and the cost curve are single elements looked up by id.
  ["#line", "#cost_line"].forEach(function(selector) {
    d3.select(selector).remove();
  });

  // The cost plot's axes share a class, so remove them all at once.
  d3.selectAll(".axis_cost").remove();
}
/**
 * Animates batch gradient descent for the line h(x) = theta0 + theta1 * x over
 * the rows bound to the scatterplot circles, then draws the cost curve.
 *
 * Exactly `iterationNumber` parameter updates are performed. The cost is
 * recorded before every update and once more after the last one, so the cost
 * history has `iterationNumber + 1` entries. Starting a new run supersedes any
 * run that is still animating.
 *
 * @param {number} iterationNumber - number of updates to perform; fractional
 *   values are floored, and negative or non-finite values are treated as 0.
 * @param {number} alpha - learning rate.
 * @param {number} theta0 - initial intercept.
 * @param {number} theta1 - initial slope.
 */
function runGradientDescent(iterationNumber, alpha, theta0, theta1) {
  // Each call takes a fresh id. A timer from an earlier call notices the id
  // changed and stops, so it can no longer move the label or add a second
  // cost curve on top of the new run's.
  var runId = (runGradientDescent.latestRunId || 0) + 1;
  runGradientDescent.latestRunId = runId;

  resetPlot();

  var data = d3.selectAll("circle").data();
  var m = data.length;

  // Without rows (e.g. the data has not been drawn yet) every average below
  // would be 0/0, so there is nothing meaningful to animate.
  if (m === 0) {
    return;
  }

  // Normalise the loop bound so the stop condition is always reachable.
  var totalIterations = isFinite(iterationNumber)
    ? Math.max(0, Math.floor(iterationNumber))
    : 0;

  var xDomain = x.domain(),
      xMin = xDomain[0],
      xMax = xDomain[1];

  // Hypothesis evaluated with the current coefficients.
  function predict(population) {
    return theta1 * population + theta0;
  }

  // Half the mean squared error of the current hypothesis.
  function computeCost() {
    var cost = 0;
    data.forEach(function(d) {
      cost += Math.pow(predict(d.population) - d.profit, 2);
    });
    return cost / (2 * m);
  }

  var line = svg.append("line")
    .attr("class", "line")
    .attr("id", "line");

  // Moves the line to the current hypothesis and refreshes the formula label.
  function drawHypothesis() {
    line.attr("x1", x(xMin))
      .attr("y1", y(predict(xMin)))
      .attr("x2", x(xMax))
      .attr("y2", y(predict(xMax)));
    hyp.text("hθ(x) = " + format(theta0) + " + " + format(theta1) + "x");
  }

  // Show the starting coefficients rather than assuming they are zero.
  drawHypothesis();

  var max_cost = computeCost();
  var d_cost = [];
  var iteration = 0;

  // 200 is passed as d3.timer's delay argument; returning true stops the timer.
  d3.timer(function() {
    if (runGradientDescent.latestRunId !== runId) {
      return true;
    }

    d_cost.push({ "iteration" : iteration, "cost" : computeCost() });

    // All updates done: the cost just recorded belongs to the final
    // coefficients, so plot and stop without stepping again.
    if (iteration >= totalIterations) {
      plotCost(d_cost, totalIterations, max_cost);
      return true;
    }

    // Accumulate both partial derivatives from the same (old) coefficients so
    // theta0 and theta1 are updated simultaneously.
    var errorSum = 0,
        weightedErrorSum = 0;
    data.forEach(function(d) {
      var error = predict(d.population) - d.profit;
      errorSum += error;
      weightedErrorSum += error * d.population;
    });
    theta0 -= alpha * (1 / m) * errorSum;
    theta1 -= alpha * (1 / m) * weightedErrorSum;

    drawHypothesis();
    iteration += 1;
    return false;
  }, 200);
}
/**
 * Draws the cost J(θ) against the iteration number inside the `cost_plot`
 * group: a smoothed curve plus a labelled x and y axis.
 *
 * @param {Array<{iteration: number, cost: number}>} d_cost - cost recorded at
 *   each iteration.
 * @param {number} iterationNumber - upper bound of the x axis.
 * @param {number} max_cost - upper bound of the y axis.
 */
function plotCost(d_cost, iterationNumber, max_cost) {
  // Side length, in pixels, of the square inset plot.
  var plotSize = 150;

  var x_cost = d3.scale.linear()
    .domain([0, iterationNumber])
    .range([0, plotSize]);
  var x_axis_cost = d3.svg.axis()
    .scale(x_cost)
    .orient("bottom")
    .ticks(4);

  var y_cost = d3.scale.linear()
    .domain([0, max_cost])
    .range([plotSize, 0]);
  var y_axis_cost = d3.svg.axis()
    .scale(y_cost)
    .orient("left")
    .ticks(4);

  // Declared locally so no global `cost_line` is created; the interpolation
  // mode is passed as the plain string name.
  var cost_line = d3.svg.line()
    .x(function(d) { return x_cost(d.iteration); })
    .y(function(d) { return y_cost(d.cost); })
    .interpolate("basis");

  // The id lets resetPlot find and remove the curve before the next run.
  cost_plot.append("path")
    .datum(d_cost)
    .attr("id", "cost_line")
    .attr("d", cost_line)
    .attr("stroke", "black")
    .attr("fill", "none");

  // Both axes carry the `axis_cost` class so resetPlot can remove them.
  cost_plot.append("g")
    .attr("class", "y axis axis_cost")
    .call(y_axis_cost)
    .append("text")
    .attr("transform", "rotate(-90)")
    .attr("x", -80)
    .attr("y", -35)
    .attr("dy", ".71em")
    .text("J(θ)");

  cost_plot.append("g")
    .attr("class", "x axis axis_cost")
    .attr("transform", "translate(0, " + plotSize + ")")
    .call(x_axis_cost)
    .append("text")
    .attr("x", 80)
    .attr("y", 30)
    .text("Num. of Iterations");
}
</script>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment