/*
 * sgdmf — GitHub gist xoba/2701331, created May 15, 2012 12:19.
 * (Gist page: "Save xoba/2701331 to your computer and use it in GitHub Desktop.")
 *
 * GitHub notice: this file contains bidirectional Unicode text that may be
 * interpreted or compiled differently than what appears below. To review,
 * open the file in an editor that reveals hidden Unicode characters.
 * Learn more about bidirectional Unicode characters.
 */
$(function() { | |
$('#title').html("sgdmf");
(function() {
  // Intro paragraph: the fragments below are joined with single spaces.
  var intro = [
    "<p>",
    "This is a demo of stochastic gradient descent applied to matrix factorization.",
    "'lambda' penalizes model complexity (in terms of rank), so",
    "smaller lambda's yield more accuracy.",
    "You can drag/drop your own images here too,",
    "although obviously the technique is more appropriate for sparse matrices or non-imaging purposes.",
    "See <a href='http://xoba-public.s3.amazonaws.com/2a027d92d63489e5042bc1bcdcfa42b2.pdf'>Recht et al.</a> and <a href='http://xoba-public.s3.amazonaws.com/752b8adf3379ea424554f7304842edbd.pdf'>Koren et al.</a> for related work.",
    "Only <span class='hl'>CHROME</span> browser is fast enough for this.",
    // Points at this demo's own gist (xoba/2701331); the previous href was an
    // unfilled "xxxxxxx" placeholder that led nowhere.
    "See <a target='_blank' href='https://gist.github.com/xoba/2701331'>gist</a> for source.",
    "Also try <a target='_blank' href='mike.png'>mike.png</a>.",
    "</p>"
  ];
  $('#main').html(intro.join(" "));
})();
/**
 * Formats a number (or numeric string) with comma thousands separators,
 * e.g. 1234567.5 -> "1,234,567.5". Any fractional part is kept verbatim.
 * @param {number|string} nStr value to format
 * @returns {string} the formatted value
 */
function fmt(nStr) {
  // All temporaries are declared locally; previously x, x1 and x2 leaked onto
  // the global object (and would throw a ReferenceError under strict mode).
  var text = nStr + '';
  var parts = text.split('.');
  var whole = parts[0];
  var fraction = parts.length > 1 ? '.' + parts[1] : '';
  // Repeatedly insert a comma before the last three digits of the leading run.
  var rgx = /(\d+)(\d{3})/;
  while (rgx.test(whole)) {
    whole = whole.replace(rgx, '$1,$2');
  }
  return whole + fraction;
}
/**
 * Returns a random id shaped like a UUID (8-4-4-4-12 lowercase hex digits).
 * Built from Math.random(), so it is only suitable for things like DOM ids.
 * @returns {string}
 */
function uuid() {
  // Four hex digits: a value in [0x10000, 0x20000) printed in hex, leading "1" dropped.
  function hex4() {
    var n = ((1 + Math.random()) * 0x10000) | 0;
    return n.toString(16).substring(1);
  }
  var quadsPerGroup = [2, 1, 1, 1, 3];
  var groups = [];
  for (var g = 0; g < quadsPerGroup.length; g++) {
    var chunk = "";
    for (var q = 0; q < quadsPerGroup[g]; q++) {
      chunk += hex4();
    }
    groups.push(chunk);
  }
  return groups.join("-");
}
var db = function(x) {
  // Decibels: 10 * log10(x), computed as a ratio of natural logs.
  var ln = Math.log;
  return 10 * ln(x) / ln(10);
};
var idb = function(x) {
  // Inverse of db: 10^(x / 10). The slider callbacks in this file pass the
  // raw .val() result here; the division coerces it to a number.
  var exponent = x / 10;
  return Math.pow(10, exponent);
};
/**
 * Appends a stats table to #main with one labelled, initially empty cell per
 * field. For each field the instance gets a method named after the cell id
 * (time, norm, gamma, rank) that sets that cell's HTML.
 */
function Status() {
  var fields = [
    ["time per frame", "time"],
    ["frobenius norm", "norm"],
    ["learning rate", "gamma"],
    ["maximum rank", "rank"]
  ];
  var setterFor = function(cellId) {
    return function(t) {
      $('#' + cellId).html(t);
    };
  };
  var markup = "";
  for (var n = 0; n < fields.length; n++) {
    var label = fields[n][0];
    var cellId = fields[n][1];
    markup += "<tr><td>" + label + ":</td><td class='right' id='" + cellId + "'></td></tr>";
    this[cellId] = setterFor(cellId);
  }
  $('#main').append("<table>" + markup + "</table>");
}
/**
 * Adds a labelled range input (300 steps between min and max) to the
 * #sliders table, creating that table inside #main on first use.
 * cb(value) is called with initialValue immediately and with the input's
 * .val() on every change event; its return value is shown next to the slider.
 */
function Slider(name, min, max, initialValue, cb) {
  var tableMissing = $('#sliders').length === 0;
  if (tableMissing) {
    $('#main').append("<table id='sliders'></table>");
  }
  var step = (max - min) / 300;
  var key = uuid();
  var inputSel = '#S' + key;
  var labelSel = '#T' + key;
  var row = "<tr><td>" + name + ":</td>" +
    "<td style='width:608px'><input type='range' value='" + initialValue +
    "' step='" + step + "' id='S" + key + "' min='" + min + "' max='" + max + "'/></td>" +
    "<td id='T" + key + "'></td></tr>";
  // Re-query #sliders: the table may have been created just above.
  $('#sliders').append(row);
  $(inputSel).change(function() {
    var current = $(inputSel).val();
    $(labelSel).html(cb(current));
  });
  $(labelSel).html(cb(initialValue));
}
var luminance = function(r, g, b) {
  // Weighted RGB -> grey level (0.3 / 0.59 / 0.11), rounded to the nearest integer.
  var y = 0.3 * r;
  y += 0.59 * g;
  y += 0.11 * b;
  return Math.round(y);
};
/**
 * Fixed-length float32 vector of w zero-initialised entries.
 * Exposes dims ([w]) and get/set/inc accessors over a private buffer.
 */
function Vector(w) {
  var bytes = new ArrayBuffer(4 * w);
  var cells = new Float32Array(bytes);
  function read(i) {
    return cells[i];
  }
  function write(i, v) {
    cells[i] = v;
  }
  function add(i, v) {
    cells[i] += v;
  }
  this.dims = [w];
  this.get = read;
  this.set = write;
  this.inc = add;
}
/**
 * Dense w x h float32 matrix, zero-initialised. Element (i, j) lives at flat
 * offset i + j * w, so entries that differ only in i are adjacent.
 * Exposes dims ([w, h]) and get/set/inc accessors.
 */
function Matrix(w, h) {
  var cells = new Float32Array(new ArrayBuffer(4 * w * h));
  var offset = function(i, j) {
    return i + j * w;
  };
  this.dims = [w, h];
  this.get = function(i, j) {
    return cells[offset(i, j)];
  };
  this.set = function(i, j, v) {
    cells[offset(i, j)] = v;
  };
  this.inc = function(i, j, v) {
    cells[offset(i, j)] += v;
  };
}
var imageToMatrix = function(img) {
  // Converts RGBA pixel data (4 bytes per pixel, rows of img.width pixels)
  // into a width x height Matrix of grey levels via luminance(); alpha is ignored.
  var w = img.width;
  var h = img.height;
  var data = img.data;
  var out = new Matrix(w, h);
  // Walk row by row so the pixel buffer is read sequentially.
  for (var row = 0; row < h; row++) {
    for (var col = 0; col < w; col++) {
      var p = 4 * (col + row * w);
      out.set(col, row, luminance(data[p], data[p + 1], data[p + 2]));
    }
  }
  return out;
};
var drawMatrix = function(ctx, img, mat) {
  // Writes mat into img as greyscale (R = G = B = value; the alpha byte is left
  // untouched), blits img to ctx at (0, 0), and returns the sum of squares of
  // the raw matrix values (not of the bytes actually stored in img.data).
  var width = img.width;
  var height = img.height;
  var data = img.data;
  var sumSq = 0;
  for (var x = 0; x < width; x++) {
    for (var y = 0; y < height; y++) {
      var v = mat.get(x, y);
      var o = 4 * (x + width * y);
      data[o] = data[o + 1] = data[o + 2] = v;
      sumSq += v * v;
    }
  }
  ctx.putImageData(img, 0, 0);
  return sumSq;
};
var rand = function(n) {
  // Uniform random integer in [0, n) (for positive n).
  var scaled = Math.random() * n;
  return Math.floor(scaled);
};
/**
 * Builds the self-rescheduling SGD frame function for `state`.
 *
 * Each call:
 *   1. runs state.niter stochastic updates of state.left (w x r) and
 *      state.right (r x h) toward state.original;
 *   2. redraws via state.draw();
 *   3. reports the draw result and the frame time through state.stats;
 *   4. schedules itself again after state.period ms.
 *
 * gamma and lambda are read once per frame. niter is re-read on every loop
 * check, and period is read when rescheduling, so slider changes apply live.
 *
 * @param {Object} state shared, mutable model/UI state
 * @returns {Function} the frame function; call it (or setTimeout it) to start
 */
var compute = function(state) {
  var f = function() {
    var startTime = new Date().getTime();
    var r = state.r;
    var original = state.original;
    var left = state.left;
    var right = state.right;
    var g = state.gamma;
    var lambda = state.lambda;
    var h = state.h;
    var w = state.w;
    // Scratch gradient buffers (float32 Vectors, like the factors).
    var lg = new Vector(r);
    var rg = new Vector(r);
    for (var iter = 0; iter < state.niter; iter++) {
      // Sample one entry (i, j) uniformly at random.
      var i = rand(w);
      var j = rand(h);
      var pred = 0;
      for (var k = 0; k < r; k++) {
        pred += left.get(i, k) * right.get(k, j);
      }
      var error = pred - original.get(i, j);
      // Gradient of 0.5 * error^2 plus an L2 penalty. Row i of `left` enters
      // h predictions, so its penalty is scaled by 1/h. Column j of `right`
      // enters w predictions, so its penalty is scaled by 1/w.
      for (k = 0; k < r; k++) {
        var lik = left.get(i, k);
        var rkj = right.get(k, j);
        lg.set(k, error * rkj + lambda * lik / h);
        rg.set(k, error * lik + lambda * rkj / w);
      }
      // Apply the step only after both gradients were taken at the old values.
      for (k = 0; k < r; k++) {
        left.inc(i, k, -g * lg.get(k));
        right.inc(k, j, -g * rg.get(k));
      }
    }
    // Presumably the sum of squares returned by drawMatrix; shown in millions.
    var f2 = state.draw();
    state.stats.norm(fmt((f2 / 1000000).toFixed(0)));
    var endTime = new Date().getTime();
    var time = endTime - startTime;
    state.stats.time(time + " ms");
    setTimeout(f, state.period);
  };
  return f;
};
/**
 * Returns an n x m Matrix whose entries are drawn uniformly from [0, scale).
 * @param {number} n first dimension (Matrix width)
 * @param {number} m second dimension (Matrix height)
 * @param {number} [scale=10] exclusive upper bound of the uniform range
 * @returns {Matrix}
 */
var randmat = function(n, m, scale) {
  // Explicit undefined check so scale = 0 is honoured rather than replaced.
  var s = scale === undefined ? 10 : scale;
  var mat = new Matrix(n, m);
  for (var i = 0; i < n; i++) {
    for (var j = 0; j < m; j++) {
      mat.set(i, j, s * Math.random());
    }
  }
  return mat;
};
/**
 * Returns the onload handler that builds the demo around the loaded `img`.
 * The handler:
 *   - adds a stats table and two canvases (c1 shows the target image,
 *     c2 the current rank-10 reconstruction left * right);
 *   - adds the lambda, iterations and frame-period sliders;
 *   - starts the SGD loop built by compute(state);
 *   - installs a drag-and-drop handler that swaps in a new target image.
 * @param {HTMLImageElement} img the initial target image
 * @returns {Function} handler to install as img.onload
 */
var run = function(img) {
  return function() {
    var h = img.height;
    var w = img.width;
    var stats = new Status();
    var drawCanvas = function(id, w, h) {
      $('#main').append("<canvas id='" + id + "' height='" + h + "px' width='" + w + "px'></canvas>");
    };
    drawCanvas('c1', w, h);
    drawCanvas('c2', w, h);
    var c1 = $('#c1')[0].getContext("2d");
    var c2 = $('#c2')[0].getContext("2d");
    c1.drawImage(img, 0, 0);
    // Also reused as the pixel buffer drawMatrix writes into for c2.
    var imageData = c1.getImageData(0, 0, w, h);
    var rank = 10;
    var f = new Matrix(w, h);
    // Declared locally; previously this was an implicit global.
    var state = {
      h: h,
      w: w,
      r: rank,
      stats: stats,
      gamma: 0.001,
      lambda: 1000,
      period: 100,
      niter: 5000,
      original: imageToMatrix(imageData),
      left: randmat(w, rank),
      right: randmat(rank, h),
      // Renders left * right into c2 and returns drawMatrix's result.
      draw: function() {
        var left = this.left;
        var right = this.right;
        for (var i = 0; i < w; i++) {
          for (var j = 0; j < h; j++) {
            var p = 0;
            for (var k = 0; k < rank; k++) {
              p += left.get(i, k) * right.get(k, j);
            }
            f.set(i, j, p);
          }
        }
        return drawMatrix(c2, imageData, f);
      }
    };
    stats.gamma(state.gamma);
    stats.rank(state.r);
    // Sliders move in decibel space (db/idb) so each spans several decades.
    new Slider('lambda', db(100), db(50000), db(state.lambda), function(v) {
      state.lambda = idb(v);
      return fmt(state.lambda.toFixed(0));
    });
    new Slider('iterations', db(100), db(100000), db(state.niter), function(v) {
      state.niter = idb(v);
      return fmt(state.niter.toFixed(0)) + " per frame";
    });
    new Slider('frame period', db(50), db(3000), db(state.period), function(v) {
      state.period = idb(v);
      return fmt(state.period.toFixed(0)) + " ms";
    });
    setTimeout(compute(state), state.period);

    // Loads `uri` and makes it the new SGD target. The image is drawn at (0, 0)
    // and read back at the original w x h, so it is clipped or padded.
    var loadImage = function(uri) {
      var dropped = new Image();
      dropped.onload = function() {
        // Clear first so a smaller or transparent image does not inherit
        // pixels from the previous target.
        c1.clearRect(0, 0, w, h);
        c1.drawImage(dropped, 0, 0);
        // NOTE(review): a cross-origin URI taints c1 and getImageData throws.
        state.original = imageToMatrix(c1.getImageData(0, 0, w, h));
      };
      dropped.src = uri;
    };
    // text/uri-list holds one URI per line; lines starting with '#' are comments.
    var firstUri = function(list) {
      var lines = (list || "").split(/\r?\n/);
      for (var n = 0; n < lines.length; n++) {
        var line = lines[n].trim();
        if (line && line.charAt(0) !== '#') {
          return line;
        }
      }
      return "";
    };
    // The drop event only fires on a target whose dragenter/dragover is cancelled.
    $('#main').bind('dragenter dragover', function(e) {
      e.preventDefault();
    });
    $('#main').bind('drop', function(e) {
      e.preventDefault();
      var dt = e.originalEvent.dataTransfer;
      // getData returns "" (not undefined) when the type is absent, so test for
      // emptiness; otherwise the dropped-file branch is unreachable.
      var uri = firstUri(dt.getData("text/uri-list"));
      if (uri) {
        loadImage(uri);
      } else if (dt.files && dt.files.length > 0) {
        var reader = new FileReader();
        reader.onload = function(event) {
          loadImage(event.currentTarget.result);
        };
        reader.readAsDataURL(dt.files[0]);
      }
    });
  };
};
// Bootstrap: build the demo once the default target image has loaded.
// onload is attached before src is assigned so a cached image cannot fire first.
var logo = new Image();
logo.onload = run(logo);
logo.src = 'mit_logo.png';
}); |
// GitHub page footer: "Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment."