Skip to content

Instantly share code, notes, and snippets.

@lacic
Created April 8, 2017 12:41
Show Gist options
  • Save lacic/3ada1ae9b5abbc400acc28411b4a3df0 to your computer and use it in GitHub Desktop.
/**
 * Trains a GravesLSTM sequence-classification network (DL4J) with early
 * stopping on a multi-GPU CUDA environment and serializes the best model
 * found to disk.
 *
 * <p>NOTE(review): the hyper-parameter values in {@link #createConfiguration}
 * look like the output of a prior hyper-parameter search — confirm before
 * changing them.
 */
public class ModelTrainer {

    /** Number of parallel training workers for the early-stopping trainer. */
    private static final Integer WORKERS = 4;
    /** Prefetch buffer size handed to the parallel trainer. */
    private static final Integer PREFETCH_BUFFER = 24;
    /** Upper bound on training epochs (early stopping may end sooner). */
    private static final Integer N_EPOCHS = 250;
    /** Mini-batch size for the sequence data iterators. */
    private static final Integer BATCH_SIZE = 300;
    /** Number of output classes; also the input width of the first layer. */
    private static final Integer CLASS_COUNT = 20;
    /** Root directory holding the train/ and validation/ CSV data. */
    private static final String BASE_DIR = "... some path";

    /** Directory the trained model file is written to. */
    private final String modelPath;
    /** File name of the serialized model. */
    private final String modelName;

    /**
     * @param modelPath directory the trained model is stored in
     * @param modelName file name used for the serialized model
     */
    public ModelTrainer(String modelPath, String modelName) {
        this.modelPath = modelPath;
        this.modelName = modelName;
    }

    public static void main(String[] args) {
        // Configure the CUDA backend for multi-GPU training: 1 GB cacheable
        // allocations, 6 GB device/host caches, cross-device access enabled.
        CudaEnvironment.getInstance().getConfiguration()
                .allowMultiGPU(true)
                .setMaximumDeviceCacheableLength(1024 * 1024 * 1024L)
                .setMaximumDeviceCache(6L * 1024 * 1024 * 1024L)
                .setMaximumHostCacheableLength(1024 * 1024 * 1024L)
                .setMaximumHostCache(6L * 1024 * 1024 * 1024L)
                .setMaximumGridSize(512)
                .setMaximumBlockSize(512)
                .allowCrossDeviceAccess(true);

        // BUG FIX: numberOfLayers was referenced when building modelName
        // before it was declared further down, which does not compile.
        // Declare it ahead of its first use.
        int numberOfLayers = 1;

        String modelName = numberOfLayers + "_layer_class_model_D1";
        String modelPath = BASE_DIR + "/train/model/";
        ModelTrainer modelTrainer = new ModelTrainer(modelPath, modelName);

        String t2Features = BASE_DIR + "/train/features/";
        String t2Labels = BASE_DIR + "/train/labels/";
        String validateFeatures = BASE_DIR + "/validation/features/";
        String validateLabels = BASE_DIR + "/validation/labels/";

        // Highest numbered CSV file (inclusive) present in each directory.
        int trainMaxFileId = 15548;
        int validationMaxFileId = 10365;

        modelTrainer.trainConfigNetwork(
                t2Features,
                t2Labels,
                validateFeatures,
                validateLabels,
                0,
                trainMaxFileId,
                validationMaxFileId,
                numberOfLayers);
    }

    /**
     * Builds the network configuration, initializes the network, and runs
     * early-stopping training against the given train/validation data.
     *
     * @param trainFeatures       directory of numbered training feature CSVs
     * @param trainLabels         directory of numbered training label CSVs
     * @param validationFeatures  directory of numbered validation feature CSVs
     * @param validationLabels    directory of numbered validation label CSVs
     * @param minFileId           first CSV file id (inclusive)
     * @param trainMaxFileId      last training CSV file id (inclusive)
     * @param validationMaxFileId last validation CSV file id (inclusive)
     * @param numberOfLayers      number of hidden GravesLSTM layers
     */
    public void trainConfigNetwork(
            String trainFeatures,
            String trainLabels,
            String validationFeatures,
            String validationLabels,
            Integer minFileId,
            Integer trainMaxFileId,
            Integer validationMaxFileId,
            Integer numberOfLayers) {
        MultiLayerConfiguration config = createConfiguration(numberOfLayers);
        MultiLayerNetwork net = createNetwork(config);
        trainNetworkWithEarlyStop(
                trainFeatures, trainLabels,
                validationFeatures, validationLabels,
                minFileId, trainMaxFileId, validationMaxFileId,
                net);
    }

    /**
     * Creates the network configuration: {@code numberOfLayers} stacked
     * GravesLSTM layers followed by a softmax RNN output layer.
     *
     * @param numberOfLayers number of hidden GravesLSTM layers
     * @return the built multi-layer configuration
     */
    private MultiLayerConfiguration createConfiguration(Integer numberOfLayers) {
        // Tuned hyper-parameters (presumably from a prior search — see class doc).
        double l2 = 0.038266652122898336;
        double learningRate = 0.06122531153512233;
        double rmsDecay = 0.31881865683702093;
        double dropOut = 0.3098825217991597;
        double clipThreshold = 55.54055480346259;
        int hiddenLayerWidth = 225;

        NeuralNetConfiguration.Builder configBuilder = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1)
                .seed(12345)
                .regularization(true)
                .l2(l2)
                .learningRate(learningRate)
                .rmsDecay(rmsDecay)
                .dropOut(dropOut)
                .updater(Updater.RMSPROP)
                .weightInit(WeightInit.XAVIER)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(clipThreshold);

        NeuralNetConfiguration.ListBuilder listBuilder = configBuilder.list();

        // For RNNs the hidden layers are built with GravesLSTM.Builder.
        // The first layer takes CLASS_COUNT inputs; the rest chain on the
        // previous layer's width.
        for (int i = 0; i < numberOfLayers; i++) {
            GravesLSTM.Builder hiddenLayerBuilder = new GravesLSTM.Builder();
            hiddenLayerBuilder.nIn(i == 0 ? CLASS_COUNT : hiddenLayerWidth);
            hiddenLayerBuilder.nOut(hiddenLayerWidth);
            hiddenLayerBuilder.activation(Activation.TANH);
            listBuilder.layer(i, hiddenLayerBuilder.build());
        }

        // RnnOutputLayer is required for RNN output. Softmax normalizes the
        // output neurons so they sum to 1 (a probability distribution).
        RnnOutputLayer.Builder outputLayerBuilder =
                new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT);
        outputLayerBuilder.activation(Activation.SOFTMAX);
        outputLayerBuilder.nIn(hiddenLayerWidth);
        outputLayerBuilder.nOut(CLASS_COUNT);
        listBuilder.layer(numberOfLayers, outputLayerBuilder.build());

        listBuilder.pretrain(false);
        listBuilder.backprop(true);
        return listBuilder.build();
    }

    /**
     * Runs early-stopping training: stops after {@code N_EPOCHS} epochs or
     * after 20 epochs without score improvement on the validation set,
     * then stores the best model seen.
     */
    private void trainNetworkWithEarlyStop(
            String trainFeatures,
            String trainLabels,
            String validationFeatures,
            String validationLabels,
            Integer minFileId,
            Integer trainMaxFileId,
            Integer validationMaxFileId,
            MultiLayerNetwork net) {
        DataSetIterator trainingData =
                getDataIterator(trainFeatures, trainLabels, minFileId, trainMaxFileId);
        DataSetIterator testData =
                getDataIterator(validationFeatures, validationLabels, minFileId, validationMaxFileId);

        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(
                                new MaxEpochsTerminationCondition(N_EPOCHS),
                                new ScoreImprovementEpochTerminationCondition(20))
                        // Validation loss decides the "best" model; average over minibatches.
                        .scoreCalculator(new DataSetLossCalculator(testData, true))
                        .evaluateEveryNEpochs(1)
                        .build();

        int averagingFrequency = 3;
        EarlyStoppingParallelTrainer<MultiLayerNetwork> wrapper =
                new EarlyStoppingParallelTrainer<>(
                        esConf,
                        net,
                        trainingData, // SequenceRecordReaderDataSetIterator
                        null,         // no MultiDataSetIterator
                        WORKERS,
                        PREFETCH_BUFFER,
                        averagingFrequency);

        // Conduct early stopping training.
        EarlyStoppingResult<MultiLayerNetwork> result = wrapper.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        storeNetworkModel(result.getBestModel());
        System.out.println("Successfully created and trained the NN model.");
    }

    /**
     * Serializes the network (with updater state) to {@code modelPath/modelName},
     * replacing any previously stored model. A {@code null} network is ignored.
     */
    private void storeNetworkModel(MultiLayerNetwork net) {
        if (net == null) {
            return;
        }
        File storedModel = new File(modelPath, modelName);
        // Delete any previously created model so the new one replaces it.
        storedModel.delete();
        // BUG FIX: the original never closed the FileOutputStream (resource
        // leak); try-with-resources guarantees it is closed.
        try (FileOutputStream fos = new FileOutputStream(storedModel)) {
            ModelSerializer.writeModel(net, fos, true);
        } catch (IOException e) {
            System.out.println("Failed storing model [" + modelName + "].");
        }
    }

    /**
     * Builds a sequence iterator over numbered CSV files
     * ({@code <dir>/<id>.csv}, ids {@code minFileId..maxFileId} inclusive),
     * pairing feature sequences with label sequences (ALIGN_END).
     *
     * @throws IllegalStateException if the record readers cannot be initialized
     */
    private DataSetIterator getDataIterator(
            String trainFeaturesPath, String trainLabelsPath, Integer minFileId, Integer maxFileId) {
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
        SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
        try {
            File featuresDirTrain = new File(trainFeaturesPath);
            File labelsDirTrain = new File(trainLabelsPath);
            trainFeatures.initialize(new NumberedFileInputSplit(
                    featuresDirTrain.getAbsolutePath() + "/%d.csv", minFileId, maxFileId));
            trainLabels.initialize(new NumberedFileInputSplit(
                    labelsDirTrain.getAbsolutePath() + "/%d.csv", minFileId, maxFileId));
        } catch (Exception e) {
            // BUG FIX: the original printed and swallowed this exception,
            // then built an iterator over uninitialized readers — failing
            // later, far from the cause. Fail fast and preserve the cause.
            throw new IllegalStateException("Error when creating training data for RNN.", e);
        }
        return new SequenceRecordReaderDataSetIterator(
                trainFeatures,
                trainLabels,
                BATCH_SIZE,
                CLASS_COUNT,
                false, // regression = false: this is a classification task
                SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
    }

    /**
     * Initializes a network from the given configuration and attaches a
     * score listener that logs every 1000 iterations.
     */
    private MultiLayerNetwork createNetwork(MultiLayerConfiguration conf) {
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1000));
        return net;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment