Example of A3C using an LSTM network and a custom implementation of the MDP API (a minimal custom MDP sketch follows the main class)
package org.deeplearning4j.examples.rl4j;
import java.io.IOException;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscreteDense;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.mdp.toy.HardDeteministicToy;
import org.deeplearning4j.rl4j.mdp.toy.HardToyState;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactorySeparateStdDense;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManager;
/**
 * An example of A3C with an LSTM network on the hard deterministic toy MDP.
 *
 * @author Matt Bushell 31/10/2017
 */
public class A3CCustomMDP {
    private static final A3CDiscrete.A3CConfiguration CUSTOM_A3C =
        new A3CDiscrete.A3CConfiguration(
            123,      //Random seed
            20,       //Max steps per epoch
            20*10*8,  //Max steps overall: 20 steps x 10 epochs x 8 threads
            8,        //Number of threads
            5,        //t_max (steps between asynchronous gradient updates)
            10,       //Number of no-op warmup steps
            0.01,     //Reward scaling factor
            0.99,     //Gamma (discount factor)
            10.0      //td-error clipping
        );

    private static final ActorCriticFactorySeparateStdDense.Configuration CUSTOM_NET_A3C =
        ActorCriticFactorySeparateStdDense.Configuration.builder()
            .useLSTM(true)
            .learningRate(1e-2)
            .l2(0)
            .numHiddenNodes(16)
            .numLayer(4)
            .build();
    public static void main(String[] args) throws IOException {
        init();
    }
    public static void init() throws IOException {

        //record the training data in rl4j-data in a new folder
        DataManager manager = new DataManager(true);

        //define the mdp; your own MDPs can be based on this toy (see the sketch after this class)
        //note: "HardDeteministicToy" is the actual (misspelled) class name in RL4J
        MDP<HardToyState, Integer, DiscreteSpace> mdp = new HardDeteministicToy();

        //define the training
        A3CDiscreteDense<HardToyState> a3c = new A3CDiscreteDense<HardToyState>(mdp, CUSTOM_NET_A3C, CUSTOM_A3C, manager);

        //start the training
        a3c.train();

        //save the trained policy (separate value and policy network files)
        ACPolicy<HardToyState> pol = a3c.getPolicy();
        pol.save("/tmp/val1.ac3model", "/tmp/pol1.ac3model");

        //reload the policy; it will be equal to "pol", but without the exploration randomness
        ACPolicy<HardToyState> pol2 = ACPolicy.load("/tmp/val1.ac3model", "/tmp/pol1.ac3model");

        //play one episode with the reloaded policy, then close the mdp
        //(close() is a no-op for this in-memory toy, but matters for e.g. HTTP-backed environments)
        mdp.reset();
        pol2.play(mdp);
        mdp.close();
    }
}
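
The gist description mentions a custom implementation of the MDP API, while the class above uses the built-in toy. Below is a minimal, hypothetical sketch of such a custom MDP against the RL4J 0.9.x interfaces, offered as a starting point rather than a definitive implementation: the ParityToy and ParityState names and the echo-the-bit reward logic are invented for illustration, and the StepReply and ArrayObservationSpace imports reflect the 0.9.x package layout as best I recall it.

package org.deeplearning4j.examples.rl4j;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;

/**
 * A minimal custom MDP (hypothetical, for illustration): the agent earns a
 * reward of 1 whenever its action echoes the single observed bit.
 */
public class ParityToy implements MDP<ParityToy.ParityState, Integer, DiscreteSpace> {

    private static final int MAX_STEP = 20;

    private final DiscreteSpace actionSpace = new DiscreteSpace(2); //two actions: 0 or 1
    private final ObservationSpace<ParityState> observationSpace =
            new ArrayObservationSpace<>(new int[] {1});             //a single-element observation

    private ParityState state;

    @Override
    public ObservationSpace<ParityState> getObservationSpace() { return observationSpace; }

    @Override
    public DiscreteSpace getActionSpace() { return actionSpace; }

    @Override
    public ParityState reset() {
        state = new ParityState(0, 0);
        return state;
    }

    @Override
    public StepReply<ParityState> step(Integer action) {
        //reward 1 if the action matches the current bit, 0 otherwise
        double reward = (action == state.bit) ? 1.0 : 0.0;
        state = new ParityState((state.bit + 1) % 2, state.step + 1);
        return new StepReply<>(state, reward, isDone(), null);
    }

    @Override
    public boolean isDone() { return state.step >= MAX_STEP; }

    @Override
    public void close() { } //nothing to release for an in-memory toy

    @Override
    public MDP<ParityState, Integer, DiscreteSpace> newInstance() { return new ParityToy(); }

    /** Observations must implement Encodable so the network can consume them. */
    public static class ParityState implements Encodable {
        private final int bit;
        private final int step;

        public ParityState(int bit, int step) {
            this.bit = bit;
            this.step = step;
        }

        @Override
        public double[] toArray() { return new double[] {bit}; }
    }
}

To train on it, substitute it where the toy is constructed above, e.g. MDP<ParityToy.ParityState, Integer, DiscreteSpace> mdp = new ParityToy(); and parameterize A3CDiscreteDense with ParityToy.ParityState. Note that newInstance() must return an independent copy, since A3C spawns one MDP per training thread.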