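ModelTrainer.java: trains a stacked GravesLSTM sequence classifier with Deeplearning4j, running multi-GPU early-stopping training through EarlyStoppingParallelTrainer. The imports below assume the DL4J/ND4J 0.8.x APIs current when this gist was published; the parallel trainer additionally needs DL4J's parallel-wrapper module and an ND4J CUDA backend on the classpath.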
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.parallelism.EarlyStoppingParallelTrainer;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ModelTrainer {

    private String modelPath;
    private String modelName;

    // Parallel training: worker count and prefetch buffer for the async iterators.
    private static final int WORKERS = 4;
    private static final int PREFETCH_BUFFER = 24;
    private static final int N_EPOCHS = 250;
    private static final int BATCH_SIZE = 300;
    private static final int CLASS_COUNT = 20;
    private static final String BASE_DIR = "... some path";
    public static void main(String[] args) {
        // Let ND4J use all available GPUs and enlarge the device/host memory caches.
        CudaEnvironment.getInstance().getConfiguration()
                .allowMultiGPU(true)
                .setMaximumDeviceCacheableLength(1024 * 1024 * 1024L)
                .setMaximumDeviceCache(6L * 1024 * 1024 * 1024L)
                .setMaximumHostCacheableLength(1024 * 1024 * 1024L)
                .setMaximumHostCache(6L * 1024 * 1024 * 1024L)
                .setMaximumGridSize(512)
                .setMaximumBlockSize(512)
                .allowCrossDeviceAccess(true);

        // numberOfLayers must be declared before it is used in the model name.
        int numberOfLayers = 1;
        String modelName = numberOfLayers + "_layer_class_model_D1";
        String modelPath = BASE_DIR + "/train/model/";
        ModelTrainer modelTrainer = new ModelTrainer(modelPath, modelName);

        String t2Features = BASE_DIR + "/train/features/";
        String t2Labels = BASE_DIR + "/train/labels/";
        String validateFeatures = BASE_DIR + "/validation/features/";
        String validateLabels = BASE_DIR + "/validation/labels/";
        int trainMaxFileId = 15548;
        int validationMaxFileId = 10365;

        modelTrainer.trainConfigNetwork(
                t2Features,
                t2Labels,
                validateFeatures,
                validateLabels,
                0,
                trainMaxFileId,
                validationMaxFileId,
                numberOfLayers);
    }
    public ModelTrainer(String modelPath, String modelName) {
        this.modelPath = modelPath;
        this.modelName = modelName;
    }
    public void trainConfigNetwork(
            String trainFeatures,
            String trainLabels,
            String validationFeatures,
            String validationLabels,
            Integer minFileId,
            Integer trainMaxFileId,
            Integer validationMaxFileId,
            Integer numberOfLayers) {
        MultiLayerConfiguration config = createConfiguration(numberOfLayers);
        MultiLayerNetwork net = createNetwork(config);
        trainNetworkWithEarlyStop(trainFeatures, trainLabels, validationFeatures, validationLabels,
                minFileId, trainMaxFileId, validationMaxFileId, net);
    }
    private MultiLayerConfiguration createConfiguration(Integer numberOfLayers) {
        // Hyperparameter values as tuned for this model.
        double l2 = 0.038266652122898336;
        double learningRate = 0.06122531153512233;
        double rmsDecay = 0.31881865683702093;
        double dropOut = 0.3098825217991597;
        double clipThreshold = 55.54055480346259;
        int hiddenLayerWidth = 225;

        NeuralNetConfiguration.Builder configBuilder = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1) // one parameter update per minibatch
                .seed(12345)
                .regularization(true)
                .l2(l2)
                .learningRate(learningRate)
                .rmsDecay(rmsDecay)
                .dropOut(dropOut)
                .updater(Updater.RMSPROP)
                .weightInit(WeightInit.XAVIER)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(clipThreshold);

        NeuralNetConfiguration.ListBuilder listBuilder = configBuilder.list();

        // The recurrent hidden layers are built with GravesLSTM.Builder.
        for (int i = 0; i < numberOfLayers; i++) {
            GravesLSTM.Builder hiddenLayerBuilder = new GravesLSTM.Builder();
            hiddenLayerBuilder.nIn(i == 0 ? CLASS_COUNT : hiddenLayerWidth);
            hiddenLayerBuilder.nOut(hiddenLayerWidth);
            hiddenLayerBuilder.activation(Activation.TANH);
            listBuilder.layer(i, hiddenLayerBuilder.build());
        }

        // An RNN needs RnnOutputLayer as its output layer.
        RnnOutputLayer.Builder outputLayerBuilder =
                new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT);
        // Softmax normalizes the output neurons so they sum to 1, giving a
        // probability distribution over the CLASS_COUNT classes, as required
        // by the multi-class cross-entropy (MCXENT) loss.
        outputLayerBuilder.activation(Activation.SOFTMAX);
        outputLayerBuilder.nIn(hiddenLayerWidth);
        outputLayerBuilder.nOut(CLASS_COUNT);
        listBuilder.layer(numberOfLayers, outputLayerBuilder.build());

        // Finish the builder: no unsupervised pretraining, plain backprop.
        listBuilder.pretrain(false);
        listBuilder.backprop(true);
        return listBuilder.build();
    }
    private void trainNetworkWithEarlyStop(
            String trainFeatures,
            String trainLabels,
            String validationFeatures,
            String validationLabels,
            Integer minFileId,
            Integer trainMaxFileId,
            Integer validationMaxFileId,
            MultiLayerNetwork net) {
        DataSetIterator trainingData = getDataIterator(trainFeatures, trainLabels, minFileId, trainMaxFileId);
        DataSetIterator testData = getDataIterator(validationFeatures, validationLabels, minFileId, validationMaxFileId);

        // Stop after N_EPOCHS at the latest, or earlier if the validation
        // score has not improved for 20 consecutive epochs.
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                .epochTerminationConditions(
                        new MaxEpochsTerminationCondition(N_EPOCHS),
                        new ScoreImprovementEpochTerminationCondition(20))
                .scoreCalculator(new DataSetLossCalculator(testData, true))
                .evaluateEveryNEpochs(1)
                .build();

        // Average the workers' parameters every 3 minibatches.
        int averagingFrequency = 3;
        EarlyStoppingParallelTrainer<MultiLayerNetwork> wrapper =
                new EarlyStoppingParallelTrainer<>(
                        esConf,
                        net,
                        trainingData, // a SequenceRecordReaderDataSetIterator
                        null,         // no MultiDataSetIterator is used
                        WORKERS,
                        PREFETCH_BUFFER,
                        averagingFrequency);

        // Conduct early stopping training.
        EarlyStoppingResult<MultiLayerNetwork> result = wrapper.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        storeNetworkModel(result.getBestModel());
        System.out.println("Successfully created and trained the NN model.");
    }
    private void storeNetworkModel(MultiLayerNetwork net) {
        if (net == null)
            return;
        File storedModel = new File(modelPath, modelName);
        // Delete any previously stored model with the same name.
        storedModel.delete();
        // try-with-resources ensures the stream is closed even on failure;
        // the final 'true' also saves the updater state for further training.
        try (FileOutputStream fos = new FileOutputStream(storedModel)) {
            ModelSerializer.writeModel(net, fos, true);
        } catch (IOException e) {
            System.out.println("Failed storing model [" + modelName + "]: " + e.getMessage());
        }
    }
    private DataSetIterator getDataIterator(String featuresPath, String labelsPath, Integer minFileId, Integer maxFileId) {
        SequenceRecordReader features = new CSVSequenceRecordReader();
        SequenceRecordReader labels = new CSVSequenceRecordReader();
        try {
            File featuresDir = new File(featuresPath);
            File labelsDir = new File(labelsPath);
            features.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", minFileId, maxFileId));
            labels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", minFileId, maxFileId));
        } catch (Exception e) {
            // Fail fast instead of continuing with uninitialized readers.
            throw new RuntimeException("Error when creating data iterator for RNN.", e);
        }
        return new SequenceRecordReaderDataSetIterator(
                features,
                labels,
                BATCH_SIZE,
                CLASS_COUNT,
                false, // classification, not regression
                SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
    }
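
    // Note on the data layout implied by the readers above: each directory is
    // expected to hold numbered files <minFileId>.csv ... <maxFileId>.csv,
    // where a feature CSV contains one sequence (one time step per row) and
    // the matching label CSV contains the class index. With regression = false
    // and ALIGN_END, labels are aligned to the final time step of each
    // sequence, as is usual for sequence classification.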
    private MultiLayerNetwork createNetwork(MultiLayerConfiguration conf) {
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // Log the training score every 1000 iterations.
        net.setListeners(new ScoreIterationListener(1000));
        return net;
    }
}
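
A minimal sketch of how the stored model could later be loaded back for inference with ModelSerializer (assuming the same DL4J 0.8.x API; the path and file name below are the placeholder values used in this gist):

import java.io.File;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

public class ModelLoader {
    public static void main(String[] args) throws Exception {
        // Same placeholder location the trainer wrote to above.
        File modelFile = new File("... some path" + "/train/model/", "1_layer_class_model_D1");
        // Restores the configuration, parameters and (since the model was
        // saved with saveUpdater = true) the updater state.
        MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
        System.out.println("Restored network with " + net.getnLayers() + " layers.");
    }
}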