Created
January 27, 2017 21:01
-
-
Save tomthetrainer/00662e318a5426537f4028124ef1ff03 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package org.deeplearning4j.VGGwebDemo; | |
| import org.datavec.image.loader.NativeImageLoader; | |
| import org.deeplearning4j.nn.graph.ComputationGraph; | |
| import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels; | |
| import org.deeplearning4j.util.ModelSerializer; | |
| import org.nd4j.linalg.api.ndarray.INDArray; | |
| import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
| import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor; | |
| import javax.servlet.MultipartConfigElement; | |
| import java.io.File; | |
| import java.io.InputStream; | |
| import java.nio.file.Files; | |
| import java.nio.file.Path; | |
| import java.nio.file.StandardCopyOption; | |
| import static spark.Spark.*; | |
| /** | |
| * Created by tomhanlon on 1/25/17. | |
| */ | |
| public class VGG16SparkJavaWebApp { | |
| public static void main(String[] args) throws Exception { | |
| String keyStoreLocation = "clientkeystore"; | |
| String keyStorePassword = "skymind"; | |
| secure(keyStoreLocation, keyStorePassword, null,null ); | |
| File locationToSave = new File("vgg16.zip"); | |
| ComputationGraph vgg16 = ModelSerializer.restoreComputationGraph(locationToSave); | |
| // make upload directory | |
| File uploadDir = new File("upload"); | |
| uploadDir.mkdir(); // create the upload directory if it doesn't exist | |
| // form | |
| String form = "<form method='post' action='getPredictions' enctype='multipart/form-data'>\n" + | |
| " <input type='file' name='uploaded_file'>\n" + | |
| " <button>Upload picture</button>\n" + | |
| "</form>"; | |
| staticFiles.location("/Users/tomhanlon/SkyMind/webcontent"); // Static files | |
| get("/hello", (req, res) -> "Hello World"); | |
| get("VGGpredict", (req, res) -> form); | |
| //post("getPredictions",(req, res) -> "GET RESULTS"); | |
| post("/getPredictions", (req, res) -> { | |
| Path tempFile = Files.createTempFile(uploadDir.toPath(), "", ""); | |
| req.attribute("org.eclipse.jetty.multipartConfig", new MultipartConfigElement("/temp")); | |
| try (InputStream input = req.raw().getPart("uploaded_file").getInputStream()) { // getPart needs to use same "name" as input field in form | |
| Files.copy(input, tempFile, StandardCopyOption.REPLACE_EXISTING); | |
| } | |
| //logInfo(req, tempFile); | |
| //return "<h1>You uploaded this image:<h1><img src='" + tempFile.getFileName() + "'>"; | |
| File file = tempFile.toFile(); | |
| //File file = new File(path); | |
| NativeImageLoader loader = new NativeImageLoader(224, 224, 3); | |
| INDArray image = loader.asMatrix(file); | |
| file.delete(); | |
| DataNormalization scaler = new VGG16ImagePreProcessor(); | |
| scaler.transform(image); | |
| //System.out.print(image); | |
| INDArray[] output = vgg16.output(false,image); | |
| String predictions = TrainedModels.VGG16.decodePredictions(output[0]); | |
| return "<h4> '" + predictions + "' </h4>" + | |
| "Would you like to try another" + | |
| form; | |
| //return "<h1>Your image is: '" + tempFile.getName(1).toString() + "' </h1>"; | |
| }); | |
| } | |
| } |
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Ed,
The main parts of this code are described here.
These two lines.
File locationToSave = new File("vgg16.zip");
ComputationGraph vgg16 = ModelSerializer.restoreComputationGraph(locationToSave);
Load the neural Network.
The Neural Network is VGG16 from the pre-trained Keras models.
I can share that code as well, but basically I loaded it from Keras, tested it and then saved with ModelSerializer
It might be worth noting the size of the saved model.
490 MB
Once the model is loaded we need to feed it an image that the user submits in a form.
This code posts a form
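String form = "<form method='post' action='getPredictions' enctype='multipart/form-data'>\n" +
" <input type='file' name='uploaded_file'>\n" +
" <button>Upload picture</button>\n" +
"</form>";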
This html request prompts to load the form
get("VGGpredict", (req, res) -> form);
Or rather this sends a request for VGGpredict and that request for that url prompts the delivery of the form.
There is some web mumbo jumbo involved that we should gloss over, basically the file you load becomes the object "file"
The file is passed to NativeImageLoader and converted to an INDArray, in this case 3 matrices of 224 x 224.
The code here
NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
INDArray image = loader.asMatrix(file);
file.delete();
Once we have it as a matrix we delete the file so our webapp does not fill up a hard drive.
We pre-process in the same way VGG16 was pre-processed
DataNormalization scaler = new VGG16ImagePreProcessor();
scaler.transform(image);
Susan's PreProcessor.. thanks Susan !!
This line asks the model for its predictions (probability over 1,000 labels)
INDArray[] output = vgg16.output(false,image);
This line converts the numbered output to string labels and takes the top 5
String predictions = TrainedModels.VGG16.decodePredictions(output[0]);
Oops, I should have noted that the VGGpredict page serves the form, and submitting the form delivers this page:
"getPredictions"
All the code in this block is executed with each submit.
post("/getPredictions", (req, res) -> {
Note that it prints the predictions from the neural net and asks them if they would like to try again.
** Note that if you write this up, the material I started to include here:
draft-VGG16Predict.md
is worth adding. VGG16 was not designed for face recognition; it was designed for the ImageNet challenge. It does great on lions, elephants and dogs, not so great on faces.