Skip to content

Instantly share code, notes, and snippets.

@MokkeMeguru
Created August 6, 2018 05:50
Show Gist options
  • Save MokkeMeguru/5b82460e3854055af6384248740348c1 to your computer and use it in GitHub Desktop.
Save MokkeMeguru/5b82460e3854055af6384248740348c1 to your computer and use it in GitHub Desktop.
TestDataSetIterator1
// TestDataSetIterator1.java
// ver 1.0.0-beta
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Toy example that builds a {@link DataSetIterator} for sequence classification over
 * variable-length integer sequences, using in-memory DataVec sequence record readers.
 */
public interface TestDataSetIterator1 {

    /**
     * Creates an iterator over five sequences of lengths 5, 6, 7, 4 and 4.
     *
     * <p>Sequence {@code k} (0-based, in the order of the length list) consists of the
     * single-feature time steps {@code 1, 2, ..., length} and carries the class label {@code k}.
     *
     * <p>{@link CollectionSequenceRecordReader} takes a collection of <em>sequences</em>. Each
     * sequence is a list of time steps, and each time step is a list of feature values. Every
     * value list is therefore added here as its own sequence. Packing all lists into a single
     * sequence would instead make each list one time step and spread its values over feature
     * columns.
     *
     * @return an iterator yielding mini-batches of up to 3 sequences, with 7 possible classes
     * @throws IOException kept in the signature for compatibility with existing callers
     * @throws InterruptedException kept in the signature for compatibility with existing callers
     */
    static DataSetIterator createDataSetIterator() throws IOException, InterruptedException {
        final List<Integer> lengthList = Arrays.asList(5, 6, 7, 4, 4);
        final int miniBatchSize = 3;
        final int numPossibleLabels = 7;

        List<List<List<Writable>>> featureSequences = new ArrayList<>();
        List<List<List<Writable>>> labelSequences = new ArrayList<>();
        int label = 0;
        for (int length : lengthList) {
            // One sequence: `length` time steps, each holding a single feature value.
            List<List<Writable>> featureSequence = new ArrayList<>(length);
            for (int t = 1; t <= length; t++) {
                featureSequence.add(Arrays.<Writable>asList(new IntWritable(t)));
            }
            featureSequences.add(featureSequence);

            // One class label per sequence, expressed as a label "sequence" with one time step.
            List<List<Writable>> labelSequence = new ArrayList<>(1);
            labelSequence.add(Arrays.<Writable>asList(new IntWritable(label++)));
            labelSequences.add(labelSequence);
        }

        SequenceRecordReader featureReader = new CollectionSequenceRecordReader(featureSequences);
        SequenceRecordReader labelReader = new CollectionSequenceRecordReader(labelSequences);

        // ALIGN_END lines up the single label time step with the last time step of its
        // feature sequence, which is the usual setup for sequence classification.
        return new SequenceRecordReaderDataSetIterator(
                featureReader,
                labelReader,
                miniBatchSize,
                numPossibleLabels,
                false, // classification (one-hot labels), not regression
                SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
    }

    /**
     * Prints the features, labels and mask arrays of every mini-batch produced by
     * {@link #createDataSetIterator()}. Masks are printed because the sequences have
     * different lengths.
     *
     * @param args ignored
     * @throws IOException propagated from {@link #createDataSetIterator()}
     * @throws InterruptedException propagated from {@link #createDataSetIterator()}
     */
    static void main(String... args) throws IOException, InterruptedException {
        DataSetIterator dsi = createDataSetIterator();
        while (dsi.hasNext()) {
            DataSet ds = dsi.next();
            System.out.println("features:\n" + ds.getFeatures());
            System.out.println("labels:\n" + ds.getLabels());
            System.out.println("features mask:\n" + ds.getFeaturesMaskArray());
            System.out.println("labels mask:\n" + ds.getLabelsMaskArray());
        }
    }
}
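// Output observed when all five value lists are packed into ONE sequence
// (listlistlist holding a single element). The printed array has shape
// [miniBatch, features, timeSteps] = [1, 5, 5]:
//
//[[[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
// [ 2.0000, 2.0000, 2.0000, 2.0000, 2.0000],
// [ 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],
// [ 4.0000, 4.0000, 4.0000, 4.0000, 4.0000],
// [ 5.0000, 5.0000, 5.0000, 0, 0]]]
// null
//
// The values 6.0000 and 7.0000 never appear. In that layout each inner list is one time step
// (one column above), not one sequence. The feature count appears to be taken from the first
// time step (5 values), so the extra values in the 6- and 7-element lists are dropped, and the
// 4-element lists get 0 for the missing fifth feature. The trailing "null" is what main prints
// when createDataSetIterator() returns null.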
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment