Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save lukasjapan/3e3c94f92b58349e032d745706b4f28f to your computer and use it in GitHub Desktop.
Save lukasjapan/3e3c94f92b58349e032d745706b4f28f to your computer and use it in GitHub Desktop.
using-deeplearning4j-to-distinguish-between-cats-and-dogs.2.kt
// Designing the model - mostly the same as shown at https://deeplearning4j.org/mnist-for-beginners
// Layer widths, named once so the hidden layer's output visibly matches the output layer's input.
val inputSize = 1000  // Number of input datapoints.
val hiddenSize = 1000 // Number of output datapoints of the hidden layer.

// Hidden layer: fully connected, ReLU activation, Xavier weight initialization.
val hiddenLayer = DenseLayer.Builder()
    .nIn(inputSize)
    .nOut(hiddenSize)
    .activation(ActivationReLU())
    .weightInit(WeightInit.XAVIER)
    .build()

// Output layer: two softmax outputs trained with negative log-likelihood loss.
val classifierLayer = OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
    .nIn(hiddenSize)
    .nOut(2)
    .activation(ActivationSoftmax())
    .weightInit(WeightInit.XAVIER)
    .build()

// Global training settings followed by the two-layer stack.
val conf = NeuralNetConfiguration.Builder()
    .seed(ThreadLocalRandom.current().nextLong()) // fresh random seed on every run
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .iterations(1)
    .learningRate(0.006)
    .updater(Nesterovs(0.9))
    .regularization(true).l2(1e-4)
    .list()
    .layer(0, hiddenLayer)
    .layer(1, classifierLayer)
    .pretrain(false).backprop(true)
    .build()

// Build the network from the configuration and initialize it.
val model = MultiLayerNetwork(conf).apply { init() }
// Define the answers: one-hot target vectors (index 0 = cat, index 1 = dog).
val cat = Nd4j.create(doubleArrayOf(1.0, 0.0))
val dog = Nd4j.create(doubleArrayOf(0.0, 1.0))
// catLabels and dogLabels hold the input labels from the training set in a list.
// Pair every input with its target, then copy into an explicitly mutable list:
// Collections.shuffle reorders its argument in place, while `+` only promises a
// read-only List, so the copy makes the in-place shuffle below well-defined.
val trainingData = (catLabels.map { DataSet(it, cat) } + dogLabels.map { DataSet(it, dog) })
    .toMutableList()
// Don't show dogs and cats one after another.
Collections.shuffle(trainingData)
// Feed all data to the model.
model.fit(ListDataSetIterator(trainingData))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment