Update of a bl.ock that adds user-supplied gradient-descent parameters and draggable data points.
Last active
February 23, 2017 17:51
-
-
Save feyderm/b415454761a825285653913a9975c935 to your computer and use it in GitHub Desktop.
Exploring Gradient Descent Parameters for Logistic Regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<html lang="en">
<meta charset="utf-8">
<title>Exploring Gradient Descent Parameters for Logistic Regression</title>
<style>
  /* Parameter form above the chart. */
  form {
    font-size: 11px;
    font-family: sans-serif;
  }

  input {
    margin-right: 10px;
    margin-bottom: 10px;
  }

  /* Separate the Submit button from the last parameter field. */
  #button {
    margin-left: 30px;
  }

  /* SVG text, e.g. the axis tick labels. */
  text {
    font-family: sans-serif;
    fill: #000000;
  }

  /* Data points: shared outline, fill chosen by group class. */
  .pts {
    stroke: #595959;
  }

  .group1 {
    fill: steelblue;
  }

  .group2 {
    fill: red;
  }

  /* Decision-boundary line drawn during gradient descent. */
  #dec_boundary {
    fill: none;
    stroke: #000000;
    stroke-width: 2px;
    opacity: 0.6;
  }
</style>
<body>
<form action="">
  <label>Number of Iterations: <input type="text" name="iterationNumber" value="400" size="4"></label>
  <label>Learning Rate: <input type="text" name="alpha" value="0.0004" size="7"></label><br>
  <label>Theta 0: <input type="text" name="theta0" value="-24.0" size="6"></label>
  <label>Theta 1: <input type="text" name="theta1" value="0.5" size="6"></label>
  <label>Theta 2: <input type="text" name="theta2" value="0.2" size="6"></label>
  <input id="button" type="button" value="Submit" onclick="updateParams(this.form)"><br>
</form>
<!-- viz -->
<div id="chart"></div>
<script src="https://d3js.org/d3.v4.min.js"></script>
<!-- Loaded over https: an http:// script on a page served over https is blocked as mixed content. -->
<script src="https://feyderm.github.io/math/math.js"></script>
<script type="text/javascript">
// Chart geometry in pixels. plot_dx / plot_dy are the SVG size minus the
// margins and serve as the far ends of the scale ranges below.
const margin = { top: 20, right: 0, bottom: 50, left: 85 };
const svg_dx = 500;
const svg_dy = 400;
const plot_dx = svg_dx - margin.right - margin.left;
const plot_dy = svg_dy - margin.top - margin.bottom;

// Data -> pixel scales; domains are set once the CSV has loaded. The y range
// is inverted so larger values are drawn higher up.
const xPos = d3.scaleLinear().range([margin.left, plot_dx]);
const yPos = d3.scaleLinear().range([plot_dy, margin.top]);

const svg = d3.select("#chart")
  .append("svg")
    .attr("width", svg_dx)
    .attr("height", svg_dy);

// Load the training points, draw the axes and the draggable points, then
// start an initial descent using the parameters currently shown in the form.
d3.csv("logistic_reg_grad_decent.csv", (error, data) => {
  if (error) {
    console.error("Could not load logistic_reg_grad_decent.csv", error);
    return;
  }

  xPos.domain(d3.extent(data, row => +row.x));
  yPos.domain(d3.extent(data, row => +row.y));

  svg.append("g")
      .attr("id", "axis_x")
      .attr("transform", `translate(0,${plot_dy + margin.bottom / 2})`)
      .call(d3.axisBottom(xPos));

  svg.append("g")
      .attr("id", "axis_y")
      .attr("transform", `translate(${margin.left / 2}, 0)`)
      .call(d3.axisLeft(yPos));

  // Rows with group "1" are drawn as circles (class group1), all others as
  // crosses (class group2).
  const inGroup1 = row => row.group === "1";

  svg.append("g")
    .selectAll("path")
    .data(data)
    .enter()
    .append("path")
      .attr("class", row => (inGroup1(row) ? "pts group1" : "pts group2"))
      .attr("d", d3.symbol().type(row => (inGroup1(row) ? d3.symbolCircle : d3.symbolCross)))
      .attr("transform", row => `translate(${xPos(+row.x)},${yPos(+row.y)})`)
      .call(d3.drag()
        .on("start", dragstarted)
        .on("drag", dragged));

  // Use the form as the single source of the starting parameters so the first
  // run always matches what the form displays.
  updateParams(document.querySelector("form"));
});
/**
 * Drag "start" handler: re-inserts the grabbed point as the last child of its
 * parent so it is painted on top of the other points while it is dragged.
 */
function dragstarted() {
  const grabbed = d3.select(this);
  grabbed.raise();
}
/**
 * Drag "drag" handler: moves the dragged point to the current pointer position.
 *
 * The position comes from d3.mouse relative to the point's parent <g>, i.e. in
 * the same coordinate system in which the point's "translate(x,y)" transform is
 * interpreted. d3.mouse also handles touch input. Raw offsetX/offsetY are not
 * used because they are relative to whichever element is under the pointer and
 * are not defined on touch events.
 *
 * @param {Object} d - bound datum (unused; the position lives in the transform)
 */
function dragged(d) {
  const [px, py] = d3.mouse(this.parentNode);
  d3.select(this).attr("transform", `translate(${px},${py})`);
}
/**
 * Logistic (sigmoid) function 1 / (1 + e^(-z)).
 *
 * Math.exp is used rather than Math.pow(Math.E, ...) because raising the
 * rounded constant Math.E to a large power magnifies its rounding error.
 *
 * @param {number} z
 * @returns {number} value in [0, 1]; evaluates to exactly 0 or 1 once e^(-z)
 *   overflows to Infinity or underflows to 0
 */
function sigmoid(z) {
  return 1 / (1 + Math.exp(-z));
}
/**
 * Batch gradient of the logistic-regression cost, translated from the Octave
 * expression  grad = (1 / m) * (h - y)' * X.
 *
 * Uses math.js whole-vector operations instead of per-element math.subset
 * calls, so it does not depend on the index format Matrix.map passes to its
 * callback or on implicit Matrix-to-number coercion.
 *
 * @param {number} m - number of training examples (rows of X)
 * @param {*} y - math.js vector of labels, one per example
 * @param {*} h - math.js vector of predictions, one per example
 * @param {*} X - math.js matrix, one row per example
 * @returns {*} math.js vector with one entry per column of X
 */
function computeGradient(m, y, h, X) {
  const residuals = math.subtract(h, y);
  // A 1-D vector times a matrix is a row-vector product, which is exactly the
  // (h - y)' * X of the Octave formula.
  return math.divide(math.multiply(residuals, X), m);
}
/**
 * Click handler for the "Submit" button: validates the form fields, removes
 * the currently drawn decision boundary, and restarts gradient descent with
 * the submitted parameters.
 *
 * Blank or non-numeric fields, and a negative iteration count, are reported
 * with an alert. In that case nothing is restarted, rather than launching a run
 * whose line coordinates would all be NaN.
 *
 * @param {HTMLFormElement} form - the parameter form (passed as this.form)
 */
function updateParams(form) {
  const fieldNames = ["iterationNumber", "alpha", "theta0", "theta1", "theta2"];
  const values = fieldNames.map(name => {
    const text = String(form[name].value).trim();
    return text === "" ? NaN : Number(text);
  });
  const invalid = fieldNames.filter((name, i) =>
    !Number.isFinite(values[i]) || (name === "iterationNumber" && values[i] < 0));
  if (invalid.length > 0) {
    alert(`Please enter valid numbers for: ${invalid.join(", ")}`);
    return;
  }

  const [iterationNumber, alpha, theta0, theta1, theta2] = values;

  // remove previous decision boundary
  d3.select("#dec_boundary").remove();
  runGradientDescent(iterationNumber, alpha, theta0, theta1, theta2);
}
// d3.timer driving the descent currently on screen, so a restart can stop it
// instead of leaving it running against a detached line.
let activeDescentTimer = null;

/**
 * Extracts the pixel offsets from a point's "translate(x,y)" transform.
 * Accepts negative and exponent-notation numbers, and a comma or whitespace
 * separator.
 *
 * @param {?string} transform - value of the element's transform attribute
 * @returns {?number[]} [x, y], or null if no finite translate is present
 */
function parseTranslate(transform) {
  const match = /translate\(\s*([^,\s)]+)(?:\s*,\s*|\s+)([^,\s)]+)\s*\)/.exec(transform || "");
  if (!match) {
    return null;
  }
  const x = Number(match[1]);
  const y = Number(match[2]);
  return Number.isFinite(x) && Number.isFinite(y) ? [x, y] : null;
}

/**
 * Endpoints, in data coordinates, of the decision boundary
 * theta0 + theta1 * x + theta2 * y = 0.
 *
 * Normally the line spans xExtent. When theta2 is 0 the boundary is the
 * vertical line x = -theta0 / theta1, drawn across yExtent.
 *
 * @returns {?number[][]} [[x1, y1], [x2, y2]], or null when there is no
 *   finite line (theta1 and theta2 both 0, or non-finite parameters)
 */
function decisionBoundaryEndpoints(theta0, theta1, theta2, xExtent, yExtent) {
  if (theta2 !== 0) {
    const yAt = x => (-1 / theta2) * (theta1 * x + theta0);
    const ends = [[xExtent[0], yAt(xExtent[0])], [xExtent[1], yAt(xExtent[1])]];
    return ends.every(p => p.every(Number.isFinite)) ? ends : null;
  }
  if (theta1 !== 0) {
    const x = -theta0 / theta1;
    return Number.isFinite(x) ? [[x, yExtent[0]], [x, yExtent[1]]] : null;
  }
  return null;
}

/**
 * Runs batch gradient descent for logistic regression on the points as they
 * are currently positioned on screen (including any the user has dragged). It
 * animates the decision boundary theta0 + theta1 * x + theta2 * y = 0.
 *
 * The boundary for the starting theta is drawn immediately. After that, one
 * gradient step is applied per d3.timer tick (starting after 200 ms) until
 * exactly iterationNumber steps have run. Starting a new run stops any run
 * still in progress and removes its line.
 *
 * @param {number} iterationNumber - number of gradient steps to apply
 * @param {number} alpha - learning rate
 * @param {number} theta0 - initial intercept
 * @param {number} theta1 - initial weight for x
 * @param {number} theta2 - initial weight for y
 */
function runGradientDescent(iterationNumber, alpha, theta0, theta1, theta2) {
  if (activeDescentTimer !== null) {
    activeDescentTimer.stop();
    activeDescentTimer = null;
  }
  d3.selectAll("#dec_boundary").remove();

  // Training data: map each point's on-screen position back to data units.
  const points = [];
  d3.selectAll(".pts").each(function (datum) {
    const xy = parseTranslate(d3.select(this).attr("transform"));
    if (xy === null) {
      console.warn("Skipping point with unparseable transform", this);
      return;
    }
    points.push({ x: xPos.invert(xy[0]), y: yPos.invert(xy[1]), group: +datum.group });
  });
  if (points.length === 0) {
    return;
  }

  const xExtent = d3.extent(points, p => p.x);
  const yExtent = d3.extent(points, p => p.y);
  const m = points.length;
  // Design matrix with a leading column of ones for the intercept term.
  const X = math.matrix(points.map(p => [1, p.x, p.y]));
  const y = math.matrix(points.map(p => p.group));
  let theta = math.matrix([theta0, theta1, theta2]);

  const boundary = svg.append("line").attr("id", "dec_boundary");

  // Both endpoints are evaluated at the coordinate they are drawn at, so the
  // segment lies exactly on the current boundary.
  function drawBoundary() {
    const ends = decisionBoundaryEndpoints(
      math.subset(theta, math.index(0)),
      math.subset(theta, math.index(1)),
      math.subset(theta, math.index(2)),
      xExtent,
      yExtent
    );
    if (ends === null) {
      boundary.style("display", "none");
      return;
    }
    boundary
        .style("display", null)
        .attr("x1", xPos(ends[0][0]))
        .attr("y1", yPos(ends[0][1]))
        .attr("x2", xPos(ends[1][0]))
        .attr("y2", yPos(ends[1][1]));
  }

  drawBoundary();
  if (!(iterationNumber > 0)) {
    return;
  }

  let iteration = 0;
  const timer = d3.timer(() => {
    const h = math.multiply(X, theta).map(z => sigmoid(z));
    const grad = computeGradient(m, y, h, X);
    // theta := theta - alpha * grad, as whole-vector operations.
    theta = math.subtract(theta, math.multiply(alpha, grad));
    drawBoundary();

    iteration += 1;
    if (iteration >= iterationNumber) {
      timer.stop();
      if (activeDescentTimer === timer) {
        activeDescentTimer = null;
      }
    }
  }, 200);
  activeDescentTimer = timer;
}
</script> | |
</body> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
x | y | group | |
---|---|---|---|
34.62365962451697 | 78.0246928153624 | 0 | |
30.28671076822607 | 43.89499752400101 | 0 | |
35.84740876993872 | 72.90219802708364 | 0 | |
60.18259938620976 | 86.30855209546826 | 1 | |
79.0327360507101 | 75.3443764369103 | 1 | |
45.08327747668339 | 56.3163717815305 | 0 | |
61.10666453684766 | 96.51142588489624 | 1 | |
75.02474556738889 | 46.55401354116538 | 1 | |
76.09878670226257 | 87.42056971926803 | 1 | |
84.43281996120035 | 43.53339331072109 | 1 | |
95.86155507093572 | 38.22527805795094 | 0 | |
75.01365838958247 | 30.60326323428011 | 0 | |
82.30705337399482 | 76.48196330235604 | 1 | |
69.36458875970939 | 97.71869196188608 | 1 | |
39.53833914367223 | 76.03681085115882 | 0 | |
53.9710521485623 | 89.20735013750205 | 1 | |
69.07014406283025 | 52.74046973016765 | 1 | |
67.94685547711617 | 46.67857410673128 | 0 | |
70.66150955499435 | 92.92713789364831 | 1 | |
76.97878372747498 | 47.57596364975532 | 1 | |
67.37202754570876 | 42.83843832029179 | 0 | |
89.67677575072079 | 65.79936592745237 | 1 | |
50.534788289883 | 48.85581152764205 | 0 | |
34.21206097786789 | 44.20952859866288 | 0 | |
77.9240914545704 | 68.9723599933059 | 1 | |
62.27101367004632 | 69.95445795447587 | 1 | |
80.1901807509566 | 44.82162893218353 | 1 | |
93.114388797442 | 38.80067033713209 | 0 | |
61.83020602312595 | 50.25610789244621 | 0 | |
38.78580379679423 | 64.99568095539578 | 0 | |
61.379289447425 | 72.80788731317097 | 1 | |
85.40451939411645 | 57.05198397627122 | 1 | |
52.10797973193984 | 63.12762376881715 | 0 | |
52.04540476831827 | 69.43286012045222 | 1 | |
40.23689373545111 | 71.16774802184875 | 0 | |
54.63510555424817 | 52.21388588061123 | 0 | |
33.91550010906887 | 98.86943574220611 | 0 | |
64.17698887494485 | 80.90806058670817 | 1 | |
74.78925295941542 | 41.57341522824434 | 0 | |
34.1836400264419 | 75.2377203360134 | 0 | |
83.90239366249155 | 56.30804621605327 | 1 | |
51.54772026906181 | 46.85629026349976 | 0 | |
94.44336776917852 | 65.56892160559052 | 1 | |
82.36875375713919 | 40.61825515970618 | 0 | |
51.04775177128865 | 45.82270145776001 | 0 | |
62.22267576120188 | 52.06099194836679 | 0 | |
77.19303492601364 | 70.45820000180959 | 1 | |
97.77159928000232 | 86.7278223300282 | 1 | |
62.07306379667647 | 96.76882412413983 | 1 | |
91.56497449807442 | 88.69629254546599 | 1 | |
79.94481794066932 | 74.16311935043758 | 1 | |
99.2725269292572 | 60.99903099844988 | 1 | |
90.54671411399852 | 43.39060180650027 | 1 | |
34.52451385320009 | 60.39634245837173 | 0 | |
50.2864961189907 | 49.80453881323059 | 0 | |
49.58667721632031 | 59.80895099453265 | 0 | |
97.64563396007767 | 68.86157272420604 | 1 | |
32.57720016809309 | 95.59854761387875 | 0 | |
74.24869136721598 | 69.82457122657193 | 1 | |
71.79646205863379 | 78.45356224515052 | 1 | |
75.3956114656803 | 85.75993667331619 | 1 | |
35.28611281526193 | 47.02051394723416 | 0 | |
56.25381749711624 | 39.26147251058019 | 0 | |
30.05882244669796 | 49.59297386723685 | 0 | |
44.66826172480893 | 66.45008614558913 | 0 | |
66.56089447242954 | 41.09209807936973 | 0 | |
40.45755098375164 | 97.53518548909936 | 1 | |
49.07256321908844 | 51.88321182073966 | 0 | |
80.27957401466998 | 92.11606081344084 | 1 | |
66.74671856944039 | 60.99139402740988 | 1 | |
32.72283304060323 | 43.30717306430063 | 0 | |
64.0393204150601 | 78.03168802018232 | 1 | |
72.34649422579923 | 96.22759296761404 | 1 | |
60.45788573918959 | 73.09499809758037 | 1 | |
58.84095621726802 | 75.85844831279042 | 1 | |
99.82785779692128 | 72.36925193383885 | 1 | |
47.26426910848174 | 88.47586499559782 | 1 | |
50.45815980285988 | 75.80985952982456 | 1 | |
60.45555629271532 | 42.50840943572217 | 0 | |
82.22666157785568 | 42.71987853716458 | 0 | |
88.9138964166533 | 69.80378889835472 | 1 | |
94.83450672430196 | 45.69430680250754 | 1 | |
67.31925746917527 | 66.58935317747915 | 1 | |
57.23870631569862 | 59.51428198012956 | 1 | |
80.36675600171273 | 90.96014789746954 | 1 | |
68.46852178591112 | 85.59430710452014 | 1 | |
42.0754545384731 | 78.84478600148043 | 0 | |
75.47770200533905 | 90.42453899753964 | 1 | |
78.63542434898018 | 96.64742716885644 | 1 | |
52.34800398794107 | 60.76950525602592 | 0 | |
94.09433112516793 | 77.15910509073893 | 1 | |
90.44855097096364 | 87.50879176484702 | 1 | |
55.48216114069585 | 35.57070347228866 | 0 | |
74.49269241843041 | 84.84513684930135 | 1 | |
89.84580670720979 | 45.35828361091658 | 1 | |
83.48916274498238 | 48.38028579728175 | 1 | |
42.2617008099817 | 87.10385094025457 | 1 | |
99.31500880510394 | 68.77540947206617 | 1 | |
55.34001756003703 | 64.9319380069486 | 1 | |
74.77589300092767 | 89.52981289513276 | 1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment