Skip to content

Instantly share code, notes, and snippets.

@JoseRFJuniorLLMs
Last active March 13, 2019 09:07
Show Gist options
  • Select an option

  • Save JoseRFJuniorLLMs/2b963d68b587bff9b998345c77347dcf to your computer and use it in GitHub Desktop.

Select an option

Save JoseRFJuniorLLMs/2b963d68b587bff9b998345c77347dcf to your computer and use it in GitHub Desktop.
IrisController.java
package com.booksgames.loja.controllers;
import com.booksgames.loja.documents.Iris;
import com.booksgames.loja.documents.IrisType;
import com.booksgames.loja.services.IrisClassifierService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;
/**
 * REST controller for iris classification, accepting cross-origin requests from any origin.
 *
 * <p>It currently exposes no endpoints: both classification handlers are commented out below
 * and kept for reference. When re-enabled they delegate to {@link IrisClassifierService}.
 */
@RestController
@CrossOrigin(origins = "*")
public class IrisController {

    /** Classifier the (disabled) endpoints delegate to. */
    @Autowired
    IrisClassifierService irisClassifierService;

    // Disabled endpoints, kept for reference:
    //
    // @GetMapping(value = "/iris/classify/class")
    // public IrisType classify(Iris iris) {
    //     return irisClassifierService.classify(iris);
    // }
    //
    // @GetMapping(value = "/iris/classify/probabilities")
    // public Map<IrisType, Float> classificationProbabilities(Iris iris) {
    //     return irisClassifierService.classificationProbabilities(iris);
    // }
}
package com.booksgames.loja.services;
import com.booksgames.loja.documents.Iris;
import com.booksgames.loja.documents.IrisType;
import java.util.Map;
/**
 * Contract for classifying an {@link Iris} sample into one of the {@link IrisType} values
 * using a trained model.
 */
public interface IrisClassifierService {

    /**
     * Obtains the model's classification for a sample.
     *
     * @param iris the sample to classify
     * @return the predicted {@link IrisType}
     */
    IrisType classify(Iris iris);

    /**
     * Obtains from the model the probability it assigns to each type.
     *
     * @param iris the sample to classify
     * @return a map from each type to its predicted probability
     */
    Map<IrisType, Float> classificationProbabilities(Iris iris);
}
package com.booksgames.loja.services.impl;
import com.booksgames.loja.documents.Iris;
import com.booksgames.loja.documents.IrisType;
import com.booksgames.loja.services.IrisClassifierService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
/**
 * {@link IrisClassifierService} backed by a TensorFlow SavedModel that is loaded once, at
 * construction time.
 *
 * <p>Every {@link Tensor} created for a request is closed as soon as its data has been read,
 * because tensors own native memory that must be released explicitly. The loaded
 * {@link SavedModelBundle} (graph and session) is released in {@link #close()}.
 */
@Service
public class IrisTensorflowClassifierService implements IrisClassifierService, AutoCloseable {
    // Example request query strings (for reference):
    //iris/classify/class?petalLength=4.4&petalWidth=1.4&sepalLength=6.7&sepalWidth=3.1
    ///iris/classify/probabilities?petalLength=1.3&petalWidth=0.3&sepalLength=5.0&sepalWidth=3.5

    private static final String FEED_OPERATION = "dnn/input_from_feature_columns/input_layer/concat";
    private static final String FETCH_OPERATION_PROBABILITIES = "dnn/head/predictions/probabilities";
    private static final String FETCH_OPERATION_CLASS_ID = "dnn/head/predictions/class_ids";

    /** Loaded model; owns the native graph and session, released in {@link #close()}. */
    private final SavedModelBundle savedModelBundle;
    private final Session modelBundleSession;
    /** Model class index i maps to irisTypes[i] (enum declaration order). */
    private final IrisType[] irisTypes;

    /**
     * Loads the SavedModel used for every subsequent classification.
     *
     * @param savedModelPath directory of the exported SavedModel ({@code loja.savedModel.path})
     * @param savedModelTags tag passed to {@link SavedModelBundle#load} ({@code loja.savedModel.tags})
     */
    @Autowired
    public IrisTensorflowClassifierService(@Value("${loja.savedModel.path}") String savedModelPath,
                                           @Value("${loja.savedModel.tags}") String savedModelTags) {
        this.savedModelBundle = SavedModelBundle.load(savedModelPath, savedModelTags);
        this.modelBundleSession = this.savedModelBundle.session();
        this.irisTypes = IrisType.values();
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalStateException if the model returns a class index with no matching
     *         {@link IrisType}
     */
    @Override
    public IrisType classify(Iris iris) {
        long category = this.fetchClassFromModel(iris);
        if (category < 0 || category >= this.irisTypes.length) {
            throw new IllegalStateException("Model returned class index " + category
                    + " but only " + this.irisTypes.length + " IrisType values exist");
        }
        return this.irisTypes[(int) category];
    }

    /**
     * {@inheritDoc}
     *
     * <p>The returned map iterates in {@link IrisType} declaration order.
     *
     * @throws IllegalStateException if the model reports more classes than {@link IrisType} has
     */
    @Override
    public Map<IrisType, Float> classificationProbabilities(Iris iris) {
        float[] probabilities = this.fetchProbabilitiesFromModel(iris);
        if (probabilities.length > this.irisTypes.length) {
            throw new IllegalStateException("Model returned " + probabilities.length
                    + " class probabilities but only " + this.irisTypes.length + " IrisType values exist");
        }
        Map<IrisType, Float> results = new EnumMap<>(IrisType.class);
        for (int i = 0; i < probabilities.length; i++) {
            results.put(this.irisTypes[i], probabilities[i]);
        }
        return results;
    }

    /**
     * Releases the native graph and session held by the loaded model. Because this bean is
     * {@link AutoCloseable}, Spring can infer this as its destroy method on context shutdown.
     */
    @Override
    public void close() {
        this.savedModelBundle.close();
    }

    /**
     * Runs the probabilities op for a single sample.
     *
     * @return the output values flattened; for a one-sample batch, one entry per class
     */
    private float[] fetchProbabilitiesFromModel(Iris iris) {
        try (Tensor input = createInputTensor(iris);
             Tensor result = this.modelBundleSession.runner()
                     .feed(FEED_OPERATION, input)
                     .fetch(FETCH_OPERATION_PROBABILITIES)
                     .run().get(0)) {
            // Flat copy sized from the tensor itself, so the class count is not hard-coded.
            FloatBuffer buffer = FloatBuffer.allocate(result.numElements());
            result.writeTo(buffer);
            return buffer.array();
        }
    }

    /**
     * Runs the class-id op for a single sample.
     *
     * @return the first (and, for a one-sample batch, only) predicted class index
     */
    private long fetchClassFromModel(Iris iris) {
        try (Tensor input = createInputTensor(iris);
             Tensor result = this.modelBundleSession.runner()
                     .feed(FEED_OPERATION, input)
                     .fetch(FETCH_OPERATION_CLASS_ID)
                     .run().get(0)) {
            if (result.numElements() == 0) {
                throw new IllegalStateException("Model returned no class id");
            }
            // Flat copy: works whether the op yields shape [1] or [1, 1].
            LongBuffer buffer = LongBuffer.allocate(result.numElements());
            result.writeTo(buffer);
            return buffer.get(0);
        }
    }

    /**
     * Builds a 1x4 float input batch from a single sample. The caller owns (and must close)
     * the returned tensor.
     */
    private static Tensor createInputTensor(Iris iris) {
        // Order of the data on the input: PetalLength, PetalWidth, SepalLength, SepalWidth
        // (taken from the saved_model, node dnn/input_from_feature_columns/input_layer/concat).
        float[][] batch = {
                {iris.getPetalLength(), iris.getPetalWidth(), iris.getSepalLength(), iris.getSepalWidth()}
        };
        return Tensor.create(batch);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment