@qharlie
Created August 7, 2016 20:10
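import java.io.IOException;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.springframework.core.io.ClassPathResource;
public class CsvSplitExample { // hypothetical wrapper class name
    public static void main(String[] args) throws IOException, InterruptedException {
        int numLinesToSkip = 0;
        String delimiter = ",";
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
        recordReader.initialize(new FileSplit(new ClassPathResource("g.csv").getFile()));
        int labelIndex = 4; // 0-based column index of the label (the 5th column)
        int numClasses = 3; // ASSUMPTION: number of distinct labels in g.csv; adjust to your data
        int batchSize = 50;
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        DataSet next = iterator.next(); // holds only the first batchSize (50) rows
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(0.8); // use 80% of this batch for training
        //...
    }
}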
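The snippet depends on DataVec, DL4J and ND4J, so here is a dependency-free illustration of what happens to each row. This is a minimal sketch and not DataVec or ND4J code: the class `CsvSplitSketch` and its helpers `parse`, `firstBatch` and `splitTestAndTrain` are hypothetical names, and the in-memory rows stand in for `g.csv`. It reads the label from a 0-based column, keeps only the first batch the way a single `iterator.next()` call does, and then splits that batch in file order with the first 80% going to training. ND4J's own `splitTestAndTrain` may differ in details such as shuffling.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

// Hypothetical, dependency-free sketch (not DataVec/ND4J code) of the CSV -> batch -> split steps.
public class CsvSplitSketch {

    /** One parsed row: the feature columns plus the integer class label. */
    static final class Example {
        final double[] features;
        final int label;

        Example(double[] features, int label) {
            this.features = features;
            this.label = label;
        }
    }

    /** The two halves of a split. */
    static final class Split {
        final List<Example> train;
        final List<Example> test;

        Split(List<Example> train, List<Example> test) {
            this.train = train;
            this.test = test;
        }
    }

    /** Parses one delimited line; column labelIndex (0-based) is the label, every other column a feature. */
    static Example parse(String line, String delimiter, int labelIndex) {
        String[] cols = line.split(Pattern.quote(delimiter));
        if (labelIndex < 0 || labelIndex >= cols.length) {
            throw new IllegalArgumentException("labelIndex out of range: " + labelIndex);
        }
        double[] features = new double[cols.length - 1];
        int label = -1;
        int f = 0;
        for (int i = 0; i < cols.length; i++) {
            if (i == labelIndex) {
                label = Integer.parseInt(cols[i].trim());
            } else {
                features[f++] = Double.parseDouble(cols[i].trim());
            }
        }
        return new Example(features, label);
    }

    /** Like a single iterator.next(): only the first batchSize examples, fewer if the data is shorter. */
    static List<Example> firstBatch(List<Example> data, int batchSize) {
        return new ArrayList<>(data.subList(0, Math.min(batchSize, data.size())));
    }

    /** In-order split: the first (int) (fractionTrain * n) examples go to train, the rest to test. */
    static Split splitTestAndTrain(List<Example> data, double fractionTrain) {
        int numTrain = (int) (fractionTrain * data.size());
        return new Split(new ArrayList<>(data.subList(0, numTrain)),
                         new ArrayList<>(data.subList(numTrain, data.size())));
    }

    public static void main(String[] args) {
        // Hypothetical stand-in for g.csv: 120 rows, 4 features, label in column 4 (0-based).
        List<Example> data = new ArrayList<>();
        for (int i = 0; i < 120; i++) {
            data.add(parse(i + ",1.5,2.5,3.5," + (i % 3), ",", 4));
        }
        List<Example> batch = firstBatch(data, 50);
        Split split = splitTestAndTrain(batch, 0.8);
        System.out.println("batch=" + batch.size() + " train=" + split.train.size()
                + " test=" + split.test.size());
    }
}
```

Running `main` prints `batch=50 train=40 test=10`. With 120 rows and a batch size of 50, only the first 50 rows reach the split. To split the whole file this way, the batch size has to be at least the number of rows.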