Created April 14, 2017 14:13
-
-
Save edumucelli/fd0cfdcb621e6f7154019c855acfeb4e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train a random-forest classifier on the built-in iris data set with caret,
# then export the fitted model as PMML so it can be scored from Java.
library(caret)
library(r2pmml)

data(iris)

# Species is the target; every other column is used as a predictor.
rf_fit <- train(Species ~ ., data = iris, method = "rf")
print(rf_fit)

# Written to the working directory; the Java Runner loads "rf_fit.pmml".
r2pmml(rf_fit, "rf_fit.pmml")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.dmg.pmml.FieldName; | |
import org.jpmml.evaluator.FieldValue; | |
import org.jpmml.evaluator.ModelEvaluator; | |
import org.jpmml.evaluator.ProbabilityDistribution; | |
import java.util.Map; | |
import java.util.concurrent.Callable; | |
class ParallelPredictor implements Callable { | |
private Map<FieldName, FieldValue> arguments; | |
private ModelEvaluator<?> evaluator; | |
public ParallelPredictor(Map<FieldName, FieldValue> arguments, ModelEvaluator<?> evaluator) { | |
this.arguments = arguments; | |
this.evaluator = evaluator; | |
} | |
@Override | |
public Double call() throws Exception { | |
return ((ProbabilityDistribution) evaluator.evaluate(arguments).get(evaluator.getTargetFieldName())).getProbability("versicolor"); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import lombok.extern.slf4j.Slf4j;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
@Slf4j | |
class Predictor { | |
private String modelFilename; | |
private ModelEvaluator<?> evaluator; | |
private static final Random random = new Random(); | |
private static final int NUMBER_OF_THREADS = 100; | |
private ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_THREADS); | |
Predictor(String modelFilename) { | |
this.modelFilename = modelFilename; | |
} | |
void buildEvaluator() { | |
PMML pmml = null; | |
File inputFilePath = new File(this.modelFilename); | |
try(InputStream in = new FileInputStream(inputFilePath)) { | |
pmml = org.jpmml.model.PMMLUtil.unmarshal(in); | |
} catch (SAXException | JAXBException | IOException e) { | |
e.printStackTrace(); | |
} | |
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); | |
this.evaluator = modelEvaluatorFactory.newModelEvaluator(pmml); | |
} | |
void predict() { | |
List<Callable<Double>> callableArguments = new ArrayList<>(); | |
int numberOfRepeats = 33; | |
int numberOfRows = 100; | |
for (int j = 0; j < numberOfRepeats; j++) { | |
for (int i = 0; i < numberOfRows; i++) { | |
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>(); | |
FieldName sepalLengthName = FieldName.create("Sepal.Length"); | |
FieldName sepalWidthName = FieldName.create("Sepal.Width"); | |
FieldName petalLengthName = FieldName.create("Petal.Length"); | |
FieldName petalWidthName = FieldName.create("Petal.Width"); | |
FieldValue sepalLengthValue = FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, randomValue()); | |
FieldValue sepalWidthValue = FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, randomValue()); | |
FieldValue petalLengthValue = FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, randomValue()); | |
FieldValue petalWidthValue = FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, randomValue()); | |
arguments.put(sepalLengthName, sepalLengthValue); | |
arguments.put(sepalWidthName, sepalWidthValue); | |
arguments.put(petalLengthName, petalLengthValue); | |
arguments.put(petalWidthName, petalWidthValue); | |
callableArguments.add(new ParallelPredictor(arguments, evaluator)); | |
} | |
try { | |
Instant start = Instant.now(); | |
executor.invokeAll(callableArguments) | |
.stream() | |
.map(future -> { | |
try { | |
return future.get(); | |
} catch (Exception e) { | |
throw new IllegalStateException(e); | |
} | |
}) | |
.collect(Collectors.toList()); | |
Instant end = Instant.now(); | |
log.info(String.valueOf(Duration.between(start, end).toMillis())); | |
} catch (InterruptedException e) { | |
e.printStackTrace(); | |
} | |
} | |
} | |
public Double randomValue() { | |
return 1 + (10 - 1) * random.nextDouble(); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class Runner { | |
public static void main(String[] args) { | |
Predictor predictor = new Predictor("rf_fit.pmml"); | |
predictor.buildEvaluator(); | |
predictor.predict(); | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.