Skip to content

Instantly share code, notes, and snippets.

@dstokes
Last active August 29, 2015 14:01
Show Gist options
  • Save dstokes/62452b0b483662a71871 to your computer and use it in GitHub Desktop.
Save dstokes/62452b0b483662a71871 to your computer and use it in GitHub Desktop.
// averaged scores across the ensemble's nets (console dump of a length-2 Float64Array; note the two scores sum to 5, not 1)
{ '0': 3.7868791304080087,
'1': 1.2131208695919908,
BYTES_PER_ELEMENT: 8,
get: [Function: get],
set: [Function: set],
slice: [Function: slice],
subarray: [Function: subarray],
buffer:
{ '0': 252,
'1': 14,
'2': 24,
'3': 73,
'4': 135,
'5': 75,
'6': 14,
'7': 64,
'8': 6,
'9': 226,
'10': 207,
'11': 109,
'12': 241,
'13': 104,
'14': 243,
'15': 63,
slice: [Function: slice],
byteLength: 16 },
length: 2,
byteOffset: 0,
byteLength: 16 }
// stats object returned by maxmin() for the scores above (maxi is used as the predicted label)
{ maxi: 0,
maxv: 3.7868791304080087,
mini: 1,
minv: 1.2131208695919908,
dv: 2.573758260816018 }
// Scans an array-like of numbers once and reports its extremes.
// Returns {maxi, maxv, mini, minv, dv}: index/value of the first maximum,
// index/value of the first minimum, and the spread dv = maxv - minv.
// For an empty input it returns {} -- callers must check length first.
var maxmin = function(w) {
  var count = w.length;
  if(count === 0) { return {}; }
  var hiIdx = 0;
  var loIdx = 0;
  var hiVal = w[0];
  var loVal = w[0];
  var k = 1;
  while(k < count) {
    var v = w[k];
    // strict comparisons keep the earliest index on ties
    if(v > hiVal) { hiVal = v; hiIdx = k; }
    if(v < loVal) { loVal = v; loIdx = k; }
    k++;
  }
  return {maxi: hiIdx, maxv: hiVal, mini: loIdx, minv: loVal, dv: hiVal - loVal};
};
// returns prediction scores for given test data point, as Vol
// uses an averaged prediction from the best ensemble_size models
// x is a Vol.
predict_soft: function(data) {
// forward prop the best networks
// and accumulate probabilities at last layer into a an output Vol
var nv = Math.min(this.ensemble_size, this.evaluated_candidates.length);
if(nv === 0) { return new convnetjs.Vol(0,0,0); } // not sure what to do here? we're not ready yet
var xout, n;
for(var j=0;j<nv;j++) {
var net = this.evaluated_candidates[j].net;
var x = net.forward(data);
if(j===0) {
xout = x;
n = x.w.length;
} else {
// add it on
for(var d=0;d<n;d++) {
xout.w[d] += x.w[d];
}
}
}
// produce average
for(var d=0;d<n;d++) {
xout.w[d] /= n;
}
return xout;
},
predict: function(data) {
var xout = this.predict_soft(data);
if(xout.w.length !== 0) {
var stats = maxmin(xout.w);
var predicted_label = stats.maxi;
} else {
var predicted_label = -1; // error out
}
return predicted_label;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment