Last active
March 13, 2019 09:07
-
-
Save JoseRFJuniorLLMs/2b963d68b587bff9b998345c77347dcf to your computer and use it in GitHub Desktop.
IrisController.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package com.booksgames.loja.controllers; | |
| import com.booksgames.loja.documents.Iris; | |
| import com.booksgames.loja.documents.IrisType; | |
| import com.booksgames.loja.services.IrisClassifierService; | |
| import org.springframework.beans.factory.annotation.Autowired; | |
| import org.springframework.web.bind.annotation.CrossOrigin; | |
| import org.springframework.web.bind.annotation.GetMapping; | |
| import org.springframework.web.bind.annotation.RestController; | |
| import java.util.Map; | |
| @CrossOrigin(origins = "*") | |
| @RestController | |
| public class IrisController { | |
| @Autowired | |
| IrisClassifierService irisClassifierService; | |
| /* @GetMapping(value = "/iris/classify/class") | |
| public IrisType classify(Iris iris) { | |
| return irisClassifierService.classify(iris); | |
| } | |
| @GetMapping(value = "/iris/classify/probabilities") | |
| public Map<IrisType, Float> classificationProbabilities(Iris iris) { | |
| return irisClassifierService.classificationProbabilities(iris); | |
| }*/ | |
| } | |
| package com.booksgames.loja.services; | |
| import com.booksgames.loja.documents.Iris; | |
| import com.booksgames.loja.documents.IrisType; | |
| import java.util.Map; | |
| public interface IrisClassifierService { | |
| /** | |
| * Method to fetch a classification from the model | |
| * @param iris the data to classify | |
| * @return the predicted type | |
| */ | |
| IrisType classify(Iris iris); | |
| /** | |
| * Method to fetch from the model the probabilities of all the types | |
| * @param iris the data to classify | |
| * @return A map relating the type with its predicted probabilities | |
| */ | |
| Map<IrisType, Float> classificationProbabilities(Iris iris); | |
| } | |
| package com.booksgames.loja.services.impl; | |
| import com.booksgames.loja.documents.Iris; | |
| import com.booksgames.loja.documents.IrisType; | |
| import com.booksgames.loja.services.IrisClassifierService; | |
| import org.springframework.beans.factory.annotation.Autowired; | |
| import org.springframework.beans.factory.annotation.Value; | |
| import org.springframework.stereotype.Service; | |
| import org.tensorflow.SavedModelBundle; | |
| import org.tensorflow.Session; | |
| import org.tensorflow.Tensor; | |
| import java.util.HashMap; | |
| import java.util.Map; | |
| @Service | |
| public class IrisTensorflowClassifierService implements IrisClassifierService { | |
| //iris/classify/class?petalLength=4.4&petalWidth=1.4&sepalLength=6.7&sepalWidth=3.1 | |
| ///iris/classify/probabilities?petalLength=1.3&petalWidth=0.3&sepalLength=5.0&sepalWidth=3.5 | |
| private final Session modelBundleSession; | |
| private final IrisType[] irisTypes; | |
| private final static String FEED_OPERATION = "dnn/input_from_feature_columns/input_layer/concat"; | |
| private final static String FETCH_OPERATION_PROBABILITIES = "dnn/head/predictions/probabilities"; | |
| private final static String FETCH_OPERATION_CLASS_ID = "dnn/head/predictions/class_ids"; | |
| @Autowired | |
| public IrisTensorflowClassifierService(@Value("${loja.savedModel.path}") String savedModelPath, | |
| @Value("${loja.savedModel.tags}") String savedModelTags) { | |
| this.modelBundleSession = SavedModelBundle.load(savedModelPath, savedModelTags).session(); | |
| this.irisTypes = IrisType.values(); | |
| } | |
| @Override | |
| public IrisType classify(Iris iris) { | |
| int category = this.fetchClassFromModel(iris); | |
| return this.irisTypes[category]; | |
| } | |
| @Override | |
| public Map<IrisType, Float> classificationProbabilities(Iris iris){ | |
| Map<IrisType, Float> results = new HashMap<>(irisTypes.length); | |
| float[][] vector = this.fetchProbabilitiesFromModel(iris); | |
| int resultsCount = vector[0].length; | |
| for (int i=0; i < resultsCount; i++){ | |
| results.put(irisTypes[i],vector[0][i]); | |
| } | |
| return results; | |
| } | |
| private float[][] fetchProbabilitiesFromModel(Iris iris) { | |
| Tensor inputTensor = IrisTensorflowClassifierService.createInputTensor(iris); | |
| Tensor result = this.modelBundleSession.runner() | |
| .feed(IrisTensorflowClassifierService.FEED_OPERATION, inputTensor) | |
| .fetch(IrisTensorflowClassifierService.FETCH_OPERATION_PROBABILITIES) | |
| .run().get(0); | |
| float[][] buffer = new float[1][3]; | |
| result.copyTo(buffer); | |
| return buffer; | |
| } | |
| private int fetchClassFromModel(Iris iris){ | |
| Tensor inputTensor = IrisTensorflowClassifierService.createInputTensor(iris); | |
| Tensor result = this.modelBundleSession.runner() | |
| .feed(IrisTensorflowClassifierService.FEED_OPERATION, inputTensor) | |
| .fetch(IrisTensorflowClassifierService.FETCH_OPERATION_CLASS_ID) | |
| .run().get(0); | |
| long[] buffer = new long[1]; | |
| result.copyTo(buffer); | |
| return (int)buffer[0]; | |
| } | |
| private static Tensor createInputTensor(Iris iris){ | |
| // order of the data on the input: PetalLength, PetalWidth, SepalLength, SepalWidth | |
| // (taken from the saved_model, node dnn/input_from_feature_columns/input_layer/concat) | |
| float[] input = {iris.getPetalLength(), iris.getPetalWidth(), iris.getSepalLength(), iris.getSepalWidth()}; | |
| float[][] data = new float[1][4]; | |
| data[0] = input; | |
| return Tensor.create(data); | |
| } | |
| } | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment