lukasjapan/2af23822e3bd5c20d64870c1cc38aca2
Created February 10, 2021 07:39
using-deeplearning4j-to-distinguish-between-cats-and-dogs.3.kt
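import java.io.FileInputStream
import org.datavec.image.loader.NativeImageLoader
import org.deeplearning4j.nn.graph.ComputationGraph
import org.deeplearning4j.util.ModelSerializer
import org.deeplearning4j.zoo.PretrainedType
import org.deeplearning4j.zoo.model.VGG16
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor

fun main() {
    // load the pretrained vgg16 model (ImageNet weights) from the model zoo
    val vgg16model = VGG16().initPretrained(PretrainedType.IMAGENET) as ComputationGraph
    // restore the cat/dog model that was saved to a file (bundled as a classpath resource)
    val catsdogsModel = object {}.javaClass.getResourceAsStream("/catdogmodel.dl4j").use {
        ModelSerializer.restoreMultiLayerNetwork(it)
    }
    // load the image, scale it to 224x224x3 and apply the vgg16 preprocessing
    val input = FileInputStream("input.png").use { image ->
        NativeImageLoader(224, 224, 3).asMatrix(image).also { VGG16ImagePreProcessor().transform(it) }
    }
    // get the vgg16 ImageNet class outputs; they are the input features of the cat/dog model
    val vgg16labels = vgg16model.outputSingle(input)
    // get cat/dog prediction values
    val output = catsdogsModel.output(vgg16labels)
    val cat = output.getDouble(0)
    val dog = output.getDouble(1)
    // make the prediction
    println(if (cat > dog) "cat" else "dog")
}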
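
// The last step above compares the two scores by hand. A minimal sketch of a
// hypothetical helper (predictLabel is not part of DL4J or of the original code)
// that picks the highest-scoring label and so works for any number of classes.
// Assumptions: the cat/dog model's outputs are ordered [cat, dog], and the scores
// are probabilities only if that model ends in a softmax layer (not confirmed here).
// Unlike the if/else above, a tie resolves to the first label ("cat").
fun predictLabel(scores: DoubleArray, labels: List<String> = listOf("cat", "dog")): Pair<String, Double> {
    require(scores.size == labels.size) { "expected ${labels.size} scores, got ${scores.size}" }
    // index of the first maximal score
    val best = scores.indices.maxByOrNull { scores[it] } ?: error("no scores given")
    return labels[best] to scores[best]
}

// Example, with the cat and dog values computed above:
// val (label, score) = predictLabel(doubleArrayOf(cat, dog))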