Example of A3C using LSTM and a custom implementation of the MDP API
package org.deeplearning4j.examples.rl4j;

import java.io.IOException;

import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscreteDense;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.mdp.toy.HardDeteministicToy;
import org.deeplearning4j.rl4j.mdp.toy.HardToyState;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactorySeparateStdDense;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManager;
/**
 * @author Matt Bushell 31/10/2017
 *
 * Example of A3C with an LSTM network on the toy MDP.
 */
public class A3CCustomMDP {

    private static final A3CDiscrete.A3CConfiguration CUSTOM_A3C =
            new A3CDiscrete.A3CConfiguration(
                    123,     //Random seed
                    20,      //Max steps per epoch
                    20*10*8, //Max total steps (10 epochs on each of the 8 threads)
                    8,       //Number of threads
                    5,       //t_max (steps between asynchronous updates)
                    10,      //Number of no-op warmup steps
                    0.01,    //Reward scaling
                    0.99,    //Gamma (discount factor)
                    10.0     //TD-error clipping
            );

    private static final ActorCriticFactorySeparateStdDense.Configuration CUSTOM_NET_A3C =
            ActorCriticFactorySeparateStdDense.Configuration
                    .builder()
                    .useLSTM(true)
                    .learningRate(1e-2)
                    .l2(0)
                    .numHiddenNodes(16)
                    .numLayer(4)
                    .build();
    public static void main(String[] args) throws IOException {
        init();
    }

    public static void init() throws IOException {

        //record the training data in rl4j-data in a new folder
        DataManager manager = new DataManager(true);

        //define the MDP; your own MDPs can be modeled on this toy
        //(sic: the class name is misspelled upstream in RL4J)
        MDP<HardToyState, Integer, DiscreteSpace> mdp = new HardDeteministicToy();

        //define the training
        A3CDiscreteDense<HardToyState> a3c = new A3CDiscreteDense<HardToyState>(mdp, CUSTOM_NET_A3C, CUSTOM_A3C, manager);

        //start the training
        a3c.train();

        //save the trained value and policy networks
        ACPolicy<HardToyState> pol = a3c.getPolicy();
        pol.save("/tmp/val1.ac3model", "/tmp/pol1.ac3model");

        //reload the policy; it will be equal to "pol", but without the exploration randomness
        ACPolicy<HardToyState> pol2 = ACPolicy.load("/tmp/val1.ac3model", "/tmp/pol1.ac3model");

        //play it before closing, so the replay runs on a live MDP
        mdp.reset();
        pol2.play(mdp);

        //close the MDP (a no-op for the toy; gym-backed MDPs close their http connection here)
        mdp.close();
    }
}
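
The gist description mentions a custom implementation of the MDP API, while the example above uses the built-in toy. As a minimal sketch of what a custom MDP could look like (modeled on HardDeteministicToy), the class below implements RL4J's MDP interface for a hypothetical 10-state chain: action 1 advances and earns a reward, action 0 stays put. SimpleChainMDP, ChainState and the chain dynamics are illustrative, not part of RL4J; the MDP, Encodable, StepReply, ArrayObservationSpace and DiscreteSpace types are from the rl4j API as of the 0.9.x line.

package org.deeplearning4j.examples.rl4j;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;

//Illustrative custom MDP (not part of RL4J): a 10-state chain
public class SimpleChainMDP implements MDP<SimpleChainMDP.ChainState, Integer, DiscreteSpace> {

    private static final int CHAIN_LENGTH = 10;

    //the observation; RL4J feeds toArray() to the network input
    public static class ChainState implements Encodable {
        private final int position;
        public ChainState(int position) { this.position = position; }
        @Override
        public double[] toArray() { return new double[] { position }; }
    }

    private final DiscreteSpace actionSpace = new DiscreteSpace(2); //two actions: 0 = stay, 1 = advance
    private final ObservationSpace<ChainState> observationSpace =
            new ArrayObservationSpace<>(new int[] { 1 });           //one input feature
    private int position = 0;

    @Override
    public ObservationSpace<ChainState> getObservationSpace() { return observationSpace; }

    @Override
    public DiscreteSpace getActionSpace() { return actionSpace; }

    @Override
    public ChainState reset() {
        position = 0;
        return new ChainState(position);
    }

    @Override
    public StepReply<ChainState> step(Integer action) {
        double reward = 0;
        if (action == 1 && position < CHAIN_LENGTH - 1) {
            position++;
            reward = 1;
        }
        //the last argument is the auxiliary info object, unused here
        return new StepReply<>(new ChainState(position), reward, isDone(), null);
    }

    @Override
    public boolean isDone() { return position == CHAIN_LENGTH - 1; }

    @Override
    public void close() {
        //nothing to release here; a gym-backed MDP would close its http client
    }

    @Override
    public MDP<ChainState, Integer, DiscreteSpace> newInstance() {
        //each A3C worker thread gets its own copy of the environment
        return new SimpleChainMDP();
    }
}

Swapping it into the example above would then be a one-line change:
MDP<SimpleChainMDP.ChainState, Integer, DiscreteSpace> mdp = new SimpleChainMDP();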