Created
May 8, 2015 00:04
-
-
Save zomux/cee5c4878c9256ecf6c0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
learn: function(r1) { | |
// perform an update on Q function | |
if(!(this.r0 == null) && this.alpha > 0) { | |
// learn from this tuple to get a sense of how "surprising" it is to the agent | |
var tderror = this.learnFromTuple(this.s0, this.a0, this.r0, this.s1, this.a1); | |
this.tderror = tderror; // a measure of surprise | |
// decide if we should keep this experience in the replay | |
if(this.t % this.experience_add_every === 0) { | |
this.exp[this.expi] = [this.s0, this.a0, this.r0, this.s1, this.a1]; | |
this.expi += 1; | |
if(this.expi > this.experience_size) { this.expi = 0; } // roll over when we run out | |
} | |
this.t += 1; | |
// sample some additional experience from replay memory and learn from it | |
for(var k=0;k<this.learning_steps_per_iteration;k++) { | |
var ri = randi(0, this.exp.length); // todo: priority sweeps? | |
var e = this.exp[ri]; | |
this.learnFromTuple(e[0], e[1], e[2], e[3], e[4]) | |
} | |
} | |
this.r0 = r1; // store for next update | |
}, | |
learnFromTuple: function(s0, a0, r0, s1, a1) { | |
// want: Q(s,a) = r + gamma * max_a' Q(s',a') | |
// compute the target Q value | |
var tmat = this.forwardQ(this.net, s1, false); | |
var qmax = r0 + this.gamma * tmat.w[R.maxi(tmat.w)]; | |
// now predict | |
var pred = this.forwardQ(this.net, s0, true); | |
var tderror = pred.w[a0] - qmax; | |
var clamp = this.tderror_clamp; | |
if(Math.abs(tderror) > clamp) { // huber loss to robustify | |
if(tderror > clamp) tderror = clamp; | |
if(tderror < -clamp) tderror = -clamp; | |
} | |
pred.dw[a0] = tderror; | |
this.lastG.backward(); // compute gradients on net params | |
// update net | |
R.updateNet(this.net, this.alpha); | |
return tderror; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment