Skip to content

Instantly share code, notes, and snippets.

@holgerbrandl
Last active May 9, 2018 08:53
Show Gist options
  • Save holgerbrandl/09fa9fddf8f337530db4e4a18ea569b1 to your computer and use it in GitHub Desktop.
Save holgerbrandl/09fa9fddf8f337530db4e4a18ea569b1 to your computer and use it in GitHub Desktop.
package playground
import DATA_ROOT
import org.datavec.api.io.labels.ParentPathLabelGenerator
import org.datavec.api.io.labels.PathLabelGenerator
import org.datavec.api.split.CollectionInputSplit
import org.datavec.api.writable.IntWritable
import org.datavec.api.writable.Writable
import org.datavec.image.recordreader.ImageRecordReader
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator
import java.io.File
import java.net.URI
/**
 * Builds a [RecordReaderDataSetIterator] over the `jpg` files located directly inside [path].
 *
 * Every image is read through an [ImageRecordReader] created with the arguments `(224, 224, 3)` and is labelled
 * from its file name: a name that ends with `"1"` (extension excluded) yields label `1`, any other name label `0`.
 *
 * @param path directory scanned (non-recursively) for files whose extension is exactly `jpg`
 * @param batchSize batch size handed to the returned iterator
 * @return iterator over the images found in [path]
 * @throws IllegalArgumentException if the content of [path] cannot be listed, e.g. because it is not a directory
 * @author Holger Brandl
 */
fun buildDataSetIterator(path: File, batchSize: Int = 64): RecordReaderDataSetIterator {
    // File.listFiles returns null (instead of throwing) when the path is not a directory or an I/O error occurs,
    // so fail early with a message that names the offending path.
    val jpgFiles = requireNotNull(path.listFiles { file -> file.extension == "jpg" }) {
        "Cannot list image files in '$path': not a directory or not readable"
    }

    val labelGenerator = object : ParentPathLabelGenerator() {
        // The label is derived from the file name rather than from the parent directory.
        override fun getLabelForPath(imagePath: String?): Writable {
            val label = if (File(imagePath).nameWithoutExtension.endsWith("1")) 1 else 0
            return IntWritable(label)
        }
    }

    val recordReader = ImageRecordReader(224, 224, 3, labelGenerator).apply {
        initialize(CollectionInputSplit(jpgFiles.map { it.toURI() }))
    }

    return RecordReaderDataSetIterator(recordReader, batchSize)
}
/**
 * Demo entry point: iterates over the images in the `train_photos` directory below [DATA_ROOT] in batches of 10
 * and prints the shape information of the features and labels of every batch.
 */
fun main(args: Array<String>) {
    val dsIterator = buildDataSetIterator(File(DATA_ROOT, "train_photos"), batchSize = 10)

    for (dataSet in dsIterator) {
        // println (not print) so that the two values and consecutive batches do not run together on one line.
        println("features: ${dataSet.featureMatrix.shapeInfoToString()}")
        println("labels ${dataSet.labels.shapeInfoToString()}")
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment