
@tomthetrainer
Created January 27, 2017 21:01
package org.deeplearning4j.VGGwebDemo;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import javax.servlet.MultipartConfigElement;
import java.io.File;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import static spark.Spark.*;

/**
 * Created by tomhanlon on 1/25/17.
 */
public class VGG16SparkJavaWebApp {
    public static void main(String[] args) throws Exception {
        // Serve over HTTPS using this Java keystore
        String keyStoreLocation = "clientkeystore";
        String keyStorePassword = "skymind";
        secure(keyStoreLocation, keyStorePassword, null, null);

        // Load the pre-trained VGG16 model that was saved with ModelSerializer
        File locationToSave = new File("vgg16.zip");
        ComputationGraph vgg16 = ModelSerializer.restoreComputationGraph(locationToSave);

        // Create the upload directory if it doesn't exist
        File uploadDir = new File("upload");
        uploadDir.mkdir();

        // Upload form; the input name must match getPart("uploaded_file") below
        String form = "<form method='post' action='getPredictions' enctype='multipart/form-data'>\n" +
                " <input type='file' name='uploaded_file'>\n" +
                " <button>Upload picture</button>\n" +
                "</form>";

        // Static files from a directory on disk: externalLocation() takes a filesystem path,
        // whereas location() expects a classpath folder
        staticFiles.externalLocation("/Users/tomhanlon/SkyMind/webcontent");

        get("/hello", (req, res) -> "Hello World");
        get("VGGpredict", (req, res) -> form);

        // Everything in this block runs on each form submission
        post("/getPredictions", (req, res) -> {
            // Copy the uploaded image into a temp file in the upload directory
            Path tempFile = Files.createTempFile(uploadDir.toPath(), "", "");
            req.attribute("org.eclipse.jetty.multipartConfig", new MultipartConfigElement("/temp"));
            try (InputStream input = req.raw().getPart("uploaded_file").getInputStream()) { // getPart needs to use same "name" as input field in form
                Files.copy(input, tempFile, StandardCopyOption.REPLACE_EXISTING);
            }

            // Convert the image to an INDArray: 3 channels of 224 x 224
            File file = tempFile.toFile();
            NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
            INDArray image = loader.asMatrix(file);
            file.delete(); // don't let uploads fill up the disk

            // Preprocess the same way the VGG16 training images were preprocessed
            DataNormalization scaler = new VGG16ImagePreProcessor();
            scaler.transform(image);

            // Inference (train = false), then decode the top 5 ImageNet labels
            INDArray[] output = vgg16.output(false, image);
            String predictions = TrainedModels.VGG16.decodePredictions(output[0]);
            return "<h4> '" + predictions + "' </h4>" +
                    "Would you like to try another image?" +
                    form;
        });
    }
}
@tomthetrainer
Author

Ed,

The main parts of this code are as follows.

These two lines load the neural network:
File locationToSave = new File("vgg16.zip");
ComputationGraph vgg16 = ModelSerializer.restoreComputationGraph(locationToSave);

The neural network is VGG16 from the pre-trained Keras models.
I can share that code as well, but basically I loaded it from Keras, tested it, and then saved it with ModelSerializer.

It might be worth noting the size of the saved model.

490 MB
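For a sense of where 490 MB comes from, the sketch below counts the parameters of the standard published VGG16 layout (13 convolution layers with 3x3 kernels, three fully connected layers, a 224 x 224 x 3 input and 1,000 classes). It is plain Java arithmetic, not a DL4J call, and the layer sizes are assumed from the VGG16 architecture rather than read from vgg16.zip.

```java
import java.util.Locale;

/**
 * Back-of-the-envelope parameter count for the standard VGG16 layout
 * (13 3x3 conv layers + 3 fully connected layers, 224x224x3 input, 1,000 classes).
 */
final class Vgg16ParamCount {

    // Output channels of the 13 convolution layers, in order.
    private static final int[] CONV_CHANNELS = {
        64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512
    };

    // Weights plus biases of a 3x3 convolution from inChannels to outChannels.
    static long convParams(int inChannels, int outChannels) {
        return (3L * 3 * inChannels + 1) * outChannels;
    }

    // Weights plus biases of a fully connected layer.
    static long denseParams(long inputs, long outputs) {
        return (inputs + 1) * outputs;
    }

    static long totalParams() {
        long total = 0;
        int in = 3; // RGB input
        for (int out : CONV_CHANNELS) {
            total += convParams(in, out);
            in = out;
        }
        // Five 2x2 max-pools shrink 224x224 to 7x7, with 512 channels.
        long flattened = 7L * 7 * 512;
        total += denseParams(flattened, 4096);
        total += denseParams(4096, 4096);
        total += denseParams(4096, 1000);
        return total;
    }

    public static void main(String[] args) {
        long params = totalParams();
        // 4 bytes per parameter if stored as 32-bit floats
        System.out.println(String.format(Locale.ROOT, "%,d parameters, about %.1f MB as float32",
                params, params * 4 / 1e6));
    }
}
```

This comes to 138,357,544 parameters, roughly 553 MB as raw 32-bit floats; the size of the saved zip also depends on how ModelSerializer stores and compresses them.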

Once the model is loaded, we need to feed it an image that the user submits through a form.

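This code defines the upload form (the form itself posts the image):

String form = "<form method='post' action='getPredictions' enctype='multipart/form-data'>\n" +
" <input type='file' name='uploaded_file'>\n" +
" <button>Upload picture</button>\n" +
"</form>";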

This route delivers the form:
get("VGGpredict", (req, res) -> form);

Or rather, a request for the VGGpredict URL prompts the server to send back the form.

There is some web mumbo jumbo involved that we can gloss over; basically, the file you upload becomes the object "file".

The file is shipped to NativeImageLoader and converted to an INDArray, in this case 3 matrices (one per color channel) of 224 x 224.

The code:

NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
INDArray image = loader.asMatrix(file);
file.delete();
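To make "3 matrices of 224 x 224" concrete, here is a dependency-free sketch of the same idea using only java.awt.image.BufferedImage: resize to 224 x 224 and split the pixels into one plane per color channel, channel first. It is not how NativeImageLoader works internally (that uses OpenCV, with its own interpolation and channel order); nearest-neighbor resizing, R, G, B order and the ChannelFirst name are simplifying assumptions. As far as we know, NativeImageLoader also adds a leading batch dimension of 1.

```java
import java.awt.image.BufferedImage;

/**
 * Illustration only: resize an image (nearest neighbor) and split it into one
 * plane per color channel, laid out channel first as [channel][row][column].
 */
final class ChannelFirst {

    static float[][][] toChw(BufferedImage img, int height, int width) {
        float[][][] chw = new float[3][height][width];
        for (int y = 0; y < height; y++) {
            int srcY = y * img.getHeight() / height;   // nearest source row
            for (int x = 0; x < width; x++) {
                int srcX = x * img.getWidth() / width; // nearest source column
                int rgb = img.getRGB(srcX, srcY);      // packed 0xRRGGBB
                chw[0][y][x] = (rgb >> 16) & 0xFF;     // red plane
                chw[1][y][x] = (rgb >> 8) & 0xFF;      // green plane
                chw[2][y][x] = rgb & 0xFF;             // blue plane
            }
        }
        return chw;
    }
}
```

Calling toChw(image, 224, 224) yields three 224 x 224 planes regardless of the uploaded image's original size.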

Once we have it as a matrix, we delete the file so our web app does not fill up the hard drive.
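One caveat: in the route, file.delete() only runs if asMatrix succeeds, so an upload that fails to load can leave a file behind. A minimal sketch of a more defensive pattern with java.nio: copy the upload into a temp file, hand the path to a consumer, and delete the file in a finally block. TempUpload, withTempCopy and the "upload-"/".img" names are hypothetical.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.function.Function;

/**
 * Copies an upload into a temp file, passes the file to a consumer, and deletes
 * it in a finally block so it is removed even if the consumer throws.
 */
final class TempUpload {

    static <T> T withTempCopy(InputStream upload, Path dir, Function<Path, T> consumer)
            throws IOException {
        Path tempFile = Files.createTempFile(dir, "upload-", ".img");
        try {
            Files.copy(upload, tempFile, StandardCopyOption.REPLACE_EXISTING);
            return consumer.apply(tempFile);
        } finally {
            Files.deleteIfExists(tempFile); // clean up on success and on failure
        }
    }
}
```

In the route, the consumer would be the code that loads the image with NativeImageLoader (wrapping its checked IOException, since a Function cannot throw one). The caller still closes the input stream, as the original try-with-resources does.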

We preprocess the image the same way the images VGG16 was trained on were preprocessed:

DataNormalization scaler = new VGG16ImagePreProcessor();
scaler.transform(image);

Susan's preprocessor. Thanks, Susan!!
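As we understand it, VGG16ImagePreProcessor subtracts the ImageNet mean pixel value of each color channel, the values VGG16 was trained with (commonly cited as 123.68, 116.779 and 103.939 for red, green and blue). A plain-Java sketch of that step on a channel-first [3][height][width] array follows; the class name and the R, G, B ordering are assumptions, so check the DL4J source for the exact constants and channel order.

```java
/**
 * Sketch of VGG-style preprocessing: subtract a fixed per-channel mean
 * from every pixel of a channel-first [3][height][width] image, in place.
 */
final class VggMeanSubtraction {

    // Commonly cited VGG16 ImageNet channel means, assumed R, G, B order.
    static final float[] MEAN_RGB = {123.68f, 116.779f, 103.939f};

    static void subtractMean(float[][][] chw) {
        for (int c = 0; c < 3; c++) {
            for (float[] row : chw[c]) {
                for (int x = 0; x < row.length; x++) {
                    row[x] -= MEAN_RGB[c];
                }
            }
        }
    }
}
```

Note there is no scaling to [0, 1] here: the pixels stay in the 0 to 255 range, only centered around zero.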

This line asks the model for its predictions (a probability for each of the 1,000 ImageNet labels):
INDArray[] output = vgg16.output(false,image);
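VGG16 ends in a softmax layer, which turns the network's 1,000 raw class scores into probabilities that sum to 1; that is why output[0] can be read as a probability per label. A minimal, numerically stable softmax in plain Java shows the transformation (the Softmax class is illustrative, not part of DL4J).

```java
/**
 * Numerically stable softmax: turns raw scores into probabilities that sum to 1.
 */
final class Softmax {

    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max); // shift by the max so exp() cannot overflow
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum; // normalize so the probabilities sum to 1
        }
        return out;
    }
}
```

Subtracting the maximum does not change the result mathematically, but it keeps scores such as 1000.0 from overflowing Math.exp.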

This line converts the numeric output to string labels and keeps the top 5:

String predictions = TrainedModels.VGG16.decodePredictions(output[0]);
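decodePredictions handles the bookkeeping of mapping output indices to ImageNet class names. The core of that step, picking the k most probable indices and looking up their labels, looks roughly like this in plain Java; the TopK class and the labels array are hypothetical, and the real method's output format is not reproduced.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class TopK {

    // Returns the labels of the k highest-probability classes, most probable first.
    static List<String> topK(double[] probs, String[] labels, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            indices.add(i);
        }
        // Sort class indices by probability, highest first.
        indices.sort(Comparator.comparingDouble((Integer n) -> probs[n]).reversed());
        List<String> result = new ArrayList<>();
        for (int j = 0; j < Math.min(k, indices.size()); j++) {
            result.add(labels[indices.get(j)]);
        }
        return result;
    }
}
```

A full sort is plenty fast for 1,000 classes; a bounded priority queue would only matter for much larger label sets.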

Oops, I should have noted that the VGGpredict page sends the form to "getPredictions", which delivers the results page.
All the code in this block is executed with each submit:
post("/getPredictions", (req, res) -> {

Note that it returns the predictions from the neural net and asks the user whether they would like to try again.
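The results page is built by plain string concatenation. A sketch of the same page as a small helper, with the prediction text HTML-escaped before it is inserted; ResultPage, render and escapeHtml are hypothetical names, and escaping is our addition rather than something the gist does.

```java
/**
 * Builds the results page: predictions, a "try again" prompt, then the upload form.
 */
final class ResultPage {

    // Minimal HTML escaping so text cannot inject markup into the page.
    static String escapeHtml(String s) {
        StringBuilder sb = new StringBuilder(s.length());
        for (char ch : s.toCharArray()) {
            switch (ch) {
                case '&': sb.append("&amp;"); break;
                case '<': sb.append("&lt;"); break;
                case '>': sb.append("&gt;"); break;
                case '"': sb.append("&quot;"); break;
                case '\'': sb.append("&#39;"); break;
                default: sb.append(ch);
            }
        }
        return sb.toString();
    }

    static String render(String predictions, String form) {
        return "<h4>" + escapeHtml(predictions) + "</h4>\n"
                + "<p>Would you like to try another image?</p>\n"
                + form;
    }
}
```

The form string is passed through unescaped because it is trusted markup defined in the app itself.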

** Note: if you write this up, the material I started in
draft-VGG16Predict.md

is worth adding. The model was not designed for face recognition; it was designed for the ImageNet challenge. It does great on lions, elephants, and dogs, but not so great on faces.
