Skip to content

Instantly share code, notes, and snippets.

@georgeyjm
Last active August 31, 2018 14:18
Show Gist options
  • Save georgeyjm/5d44eb6a913b1b0dfe8108abfee6ce5e to your computer and use it in GitHub Desktop.
Save georgeyjm/5d44eb6a913b1b0dfe8108abfee6ce5e to your computer and use it in GitHub Desktop.
Interactive Logistic Regression with TensorFlow.js + p5.js
<!DOCTYPE html>
<html>
<head>
<meta charset='UTF-8'>
<meta http-equiv='X-UA-Compatible' content='IE=edge'>
<meta name='viewport' content='width=device-width, initial-scale=1'>
<style>
* {margin: 0}
</style>
<!-- Host page for sketch.js: loads TensorFlow.js (model + optimizer), p5.js (canvas,
     setup/draw/mousePressed hooks) and jQuery slim (used by sketch.js to suppress the
     canvas context menu so right-click can label points). -->
<!-- NOTE(review): the tfjs URL below appears to have been mangled by e-mail obfuscation
     when this page was scraped ("[email protected]" is not a valid package spec). Restore
     the intended "tfjs@VERSION" pin before use. TODO confirm which tfjs version the
     sketch was written against (it relies on tf.losses.logLoss and Tensor.as1D). -->
<script src='https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]'></script>
<script src='https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.0/p5.js'></script>
<script src='https://cdnjs.cloudflare.com/ajax/libs/jquery/3.3.1/jquery.slim.min.js'></script>
<script src='sketch.js'></script>
<title>Interactive Logistic Regression</title>
</head>
</html>
// Training data collected from mouse clicks. Each x_vals row is [1, x, y]:
// the leading 1 multiplies the bias term, and x, y are in [-1, 1] with y up.
// y_vals holds the matching class label: 0 for a left click, 1 for a right click.
// Declared const because the arrays are only appended to, never reassigned.
const x_vals = [];
const y_vals = [];
// Model parameters: a 1x3 tf.variable [bias, w_x, w_y], created in setup().
let theta;
// Step size for the Adagrad optimizer used by draw().
const learningRate = 4;
const optimizer = tf.train.adagrad(learningRate);
/**
 * p5 setup hook.
 *
 * Suppresses the browser context menu on the canvas so right-click can be used to
 * add class-1 points. Creates a canvas that fills the browser viewport. Initialises
 * theta = [bias, w_x, w_y] to zeros as a trainable 1x3 variable.
 */
function setup() {
  $('body').on('contextmenu', 'canvas', () => false);
  // windowWidth/windowHeight is the visible viewport. displayWidth/displayHeight is the
  // whole physical screen, which is larger than the page and forces scrolling.
  createCanvas(windowWidth, windowHeight);
  theta = tf.variable(tf.zeros([1, 3]));
}

/**
 * p5 hook: keep the canvas matched to the viewport when the window is resized.
 * Points are stored in normalised [-1, 1] coordinates, so they are redrawn in the
 * right place at the new size.
 */
function windowResized() {
  resizeCanvas(windowWidth, windowHeight);
}
/**
 * p5 mouse hook: records a labelled training point at the cursor.
 *
 * A left click adds a class-0 point and a right click adds a class-1 point. Other
 * buttons (e.g. the middle button) and presses outside the canvas are ignored.
 * Coordinates are stored in data space ([-1, 1] on both axes, y pointing up) with a
 * leading 1 for the bias term.
 */
function mousePressed() {
  // Previously any non-left button fell through to class 1.
  if (mouseButton !== LEFT && mouseButton !== RIGHT) return;
  // Presses outside the canvas would map outside [-1, 1].
  if (mouseX < 0 || mouseX > width || mouseY < 0 || mouseY > height) return;
  const mx = map(mouseX, 0, width, -1, 1);
  const my = map(mouseY, 0, height, 1, -1);
  x_vals.push([1, mx, my]);
  y_vals.push(mouseButton === LEFT ? 0 : 1);
}
/**
 * Model prediction: the sigmoid of each row of `x` dotted with theta.
 *
 * @param {tf.Tensor2D} x - one row per point, laid out like the rows of x_vals ([1, x, y]).
 * @returns {tf.Tensor1D} one sigmoid output per row, flattened to 1-D.
 */
function f(x) {
  const thetaColumn = theta.transpose();
  const logits = x.matMul(thetaColumn);
  const activations = logits.sigmoid();
  return activations.as1D();
}
/**
 * p5 draw loop.
 *
 * Runs one optimizer step on all collected points. It then renders the two predicted
 * class regions, the decision boundary and the points.
 *
 * The model outputs sigmoid(t0 + t1*x + t2*y). Rounded to a label, that is class 1
 * where t0 + t1*x + t2*y >= 0 (Math.round(0.5) === 1) and class 0 otherwise.
 *
 * Each class region is drawn as the visible square clipped to its half-plane. This
 * stays correct for every boundary orientation. The earlier intercept-based polygons
 * divided by t1 and t2, so they broke for exactly horizontal or vertical boundaries.
 */
function draw() {
  background('#EEEEEE');
  if (x_vals.length === 0) return;

  tf.tidy(() => {
    // Perform one optimization step over every point collected so far.
    const x = tf.tensor2d(x_vals);
    const y = tf.tensor1d(y_vals);
    optimizer.minimize(() => tf.losses.logLoss(y, f(x)));
  });

  // theta is a tf.variable, so it is not disposed by the tidy() above.
  const [t0, t1, t2] = theta.dataSync();

  // Visible area in data coordinates, corners listed in drawing order.
  const dataSquare = [[-1, 1], [1, 1], [1, -1], [-1, -1]];

  // Color the two classes: class 0 where t.[1, x, y] < 0, class 1 where it is >= 0.
  noStroke();
  fillDataPolygon(clipHalfPlane(dataSquare, -t0, -t1, -t2), 'rgba(37, 130, 191, 0.25)');
  fillDataPolygon(clipHalfPlane(dataSquare, t0, t1, t2), 'rgba(240, 154, 66, 0.25)');

  // Draw decision boundary (skipped when the model ignores both coordinates).
  const ends = boundaryEndpoints(t0, t1, t2);
  if (ends !== null) {
    const [[x1, y1], [x2, y2]] = ends.map((pt) => dataToScreen(pt));
    stroke(0);
    strokeWeight(2);
    line(x1, y1, x2, y2);
  }

  // Draw points
  strokeWeight(10);
  for (let i = 0; i < x_vals.length; i++) {
    const [px, py] = dataToScreen([x_vals[i][1], x_vals[i][2]]);
    stroke(y_vals[i] === 0 ? '#2582BF' : '#F09A42');
    point(px, py);
  }
}

/**
 * Clip a convex polygon to the closed half-plane a + b*x + c*y >= 0.
 * This is one Sutherland-Hodgman pass.
 *
 * @param {Array<[number, number]>} polygon - vertices in order.
 * @param {number} a - constant term.
 * @param {number} b - x coefficient.
 * @param {number} c - y coefficient.
 * @returns {Array<[number, number]>} the clipped vertices in the same winding order.
 *   Empty when the polygon lies entirely outside the half-plane, or when any
 *   coefficient is NaN.
 */
function clipHalfPlane(polygon, a, b, c) {
  const side = ([px, py]) => a + b * px + c * py;
  const clipped = [];
  polygon.forEach((current, i) => {
    const next = polygon[(i + 1) % polygon.length];
    const sCur = side(current);
    const sNext = side(next);
    if (sCur >= 0) clipped.push(current);
    // When the edge strictly crosses the line, add the intersection point.
    if ((sCur > 0 && sNext < 0) || (sCur < 0 && sNext > 0)) {
      const t = sCur / (sCur - sNext);
      clipped.push([
        current[0] + t * (next[0] - current[0]),
        current[1] + t * (next[1] - current[1]),
      ]);
    }
  });
  return clipped;
}

/**
 * Two points on the line a + b*x + c*y = 0, used to draw the decision boundary.
 *
 * When the line is closer to horizontal (|c| >= |b|) the points are taken at x = -1
 * and x = 1. Otherwise they are taken at y = -1 and y = 1. This way the divisor is
 * never zero.
 *
 * @returns {Array<[number, number]>|null} the two endpoints in data coordinates, or
 *   null when b and c are both 0 (there is no boundary line).
 */
function boundaryEndpoints(a, b, c) {
  if (b === 0 && c === 0) return null;
  if (Math.abs(c) >= Math.abs(b)) {
    return [-1, 1].map((x) => [x, -(a + b * x) / c]);
  }
  return [-1, 1].map((y) => [-(a + c * y) / b, y]);
}

/**
 * Map an [x, y] point from data coordinates ([-1, 1], y up) to canvas pixels (y down).
 */
function dataToScreen([x, y]) {
  return [map(x, -1, 1, 0, width), map(y, -1, 1, height, 0)];
}

/**
 * Fill a polygon given in data coordinates with `fillColor`.
 * Does nothing for degenerate polygons with fewer than 3 vertices.
 */
function fillDataPolygon(polygon, fillColor) {
  if (polygon.length < 3) return;
  fill(fillColor);
  beginShape();
  for (const dataVertex of polygon) {
    const [sx, sy] = dataToScreen(dataVertex);
    vertex(sx, sy);
  }
  endShape(CLOSE);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment