Skip to content

Instantly share code, notes, and snippets.

@MokkeMeguru
Created August 5, 2018 22:45
Show Gist options
  • Save MokkeMeguru/53d5e79f92d89fac8ea81daf37397dfe to your computer and use it in GitHub Desktop.
Save MokkeMeguru/53d5e79f92d89fac8ea81daf37397dfe to your computer and use it in GitHub Desktop.
TestDataSetIterator.java
// TestDataSetIterator.java
// ver 1.0.0-beta
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.File;
import java.io.IOException;
/**
 * Small reproduction harness: builds a sequence {@link DataSetIterator} from a feature CSV
 * and a label CSV, then prints the first batch the iterator returns.
 */
public interface TestDataSetIterator {

    /**
     * Builds the iterator over {@code resources/test_train.csv} (features) and
     * {@code resources/test_label.csv} (labels).
     *
     * <p>Both files are read with a comma delimiter and no skipped lines, and the two
     * sequences are combined using {@code ALIGN_END}.
     *
     * @return an iterator pairing the feature reader with the label reader
     * @throws IOException if initializing either record reader fails with it
     * @throws InterruptedException if initializing either record reader fails with it
     */
    static DataSetIterator createDataSetIterator() throws IOException, InterruptedException {
        // Feature side: one reader bound to the training CSV.
        final File featureFile = new File("resources/test_train.csv");
        final SequenceRecordReader featureReader = new CSVSequenceRecordReader(0, ",");
        featureReader.initialize(new FileSplit(featureFile));

        // Label side: a second, independent reader bound to the label CSV.
        final File labelFile = new File("resources/test_label.csv");
        final SequenceRecordReader labelReader = new CSVSequenceRecordReader(0, ",");
        labelReader.initialize(new FileSplit(labelFile));

        // Name the positional constructor arguments so the call site is self-describing.
        final int miniBatchSize = 3;
        final int numPossibleLabels = 6;
        final boolean regression = false;
        final SequenceRecordReaderDataSetIterator.AlignmentMode alignment =
                SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END;

        return new SequenceRecordReaderDataSetIterator(
                featureReader, labelReader, miniBatchSize, numPossibleLabels, regression, alignment);
    }

    /**
     * Entry point: creates the iterator and prints its first element to standard output.
     *
     * @param args command-line arguments (unused)
     * @throws IOException propagated from {@link #createDataSetIterator()}
     * @throws InterruptedException propagated from {@link #createDataSetIterator()}
     */
    static void main(String... args) throws IOException, InterruptedException {
        final DataSetIterator iterator = createDataSetIterator();
        System.out.println(iterator.next());
    }
}
// --- test_train.csv ---
// 1,2,3,4,5
// 1,2,3,4,5,6
// 1,2,3,4,5,6,7
// 1,2,3
// 1,2,3,4,5,6
// -----------------------
// --- test_label.csv ---
// 1
// 2
// 3
// 4
// 5
// -----------------------
//
// Exception in thread "main" java.lang.IndexOutOfBoundsException: 25
// at org.bytedeco.javacpp.indexer.Indexer.checkIndex(Indexer.java:90)
// at org.bytedeco.javacpp.indexer.FloatRawIndexer.put(FloatRawIndexer.java:90)
// at org.nd4j.linalg.api.buffer.BaseDataBuffer.put(BaseDataBuffer.java:1116)
// at org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer.put(BaseCudaDataBuffer.java:684)
// at org.nd4j.linalg.api.ndarray.BaseNDArray.putScalar(BaseNDArray.java:1414)
// at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertWritablesSequence(RecordReaderMultiDataSetIterator.java:661)
// at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertFeaturesOrLabels(RecordReaderMultiDataSetIterator.java:367)
// at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.nextMultiDataSet(RecordReaderMultiDataSetIterator.java:325)
// at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.next(RecordReaderMultiDataSetIterator.java:213)
// at org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator.next(SequenceRecordReaderDataSetIterator.java:345)
// at org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator.next(SequenceRecordReaderDataSetIterator.java:324)
// at org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator.next(SequenceRecordReaderDataSetIterator.java:32)
// at TestDataSetIterator.main(TestDataSetIterator.java:28)
//
// Process finished with exit code 1
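// But with the rows of test_train.csv reordered so that the longest row comes first,
// the same code runs without an exception and prints the batch shown below: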
//
// --- test_train.csv ---
// 1,2,3,4,5,6,7
// 1,2,3,4,5,6
// 1,2,3,4,5
// 1,2,3
// 1,2,3,4,5,6
// -----------------------
// --- test_label.csv ---
// 1
// 2
// 3
// 4
// 5
// -----------------------
//
// ===========INPUT===================
// [[[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
// [ 2.0000, 2.0000, 2.0000, 2.0000, 2.0000],
// [ 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],
// [ 4.0000, 4.0000, 4.0000, 0, 4.0000],
// [ 5.0000, 5.0000, 5.0000, 0, 5.0000],
// [ 6.0000, 0, 6.0000, 0, 6.0000],
// [ 7.0000, 0, 0, 0, 0]]]
// =================OUTPUT==================
// [[[ 0, 0, 0, 0, 0],
// [ 1.0000, 0, 0, 0, 0],
// [ 0, 1.0000, 0, 0, 0],
// [ 0, 0, 1.0000, 0, 0],
// [ 0, 0, 0, 1.0000, 0],
// [ 0, 0, 0, 0, 1.0000]]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment