Created
July 21, 2020 19:33
-
-
Save AbdealiLoKo/1dd5b7677435ba22f9ab3e26016bb3e7 to your computer and use it in GitHub Desktop.
Comparing py-java libraries
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: | |
# PYJAVA_LIB=jpype venv/bin/python pyjava.py | |
import os | |
from datetime import datetime | |
from jpmml_evaluator import _package_classpath | |
# Which Python<->Java bridge to benchmark; selected through the PYJAVA_LIB env var.
lib = os.environ.get('PYJAVA_LIB')
if lib not in ('py4j', 'jnius', 'jpype'):
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # and an unrecognised value would otherwise only surface later as a NameError
    # on the first use of one of the j* class handles.
    raise SystemExit(f"Set env var PYJAVA_LIB to py4j/jnius/jpype (got {lib!r})")

##### Create a JVM
# Each branch starts (or attaches to) a JVM with the jpmml-evaluator jars on the
# classpath and binds the same set of Java class handles (jString, jDouble, ...)
# so the remaining sections are bridge-agnostic.
from time import perf_counter  # monotonic, high-resolution clock for benchmarking

# The timed region deliberately includes importing the bridge library itself.
start_t = perf_counter()
if lib == 'py4j':
    from py4j.java_gateway import JavaGateway

    gateway = JavaGateway.launch_gateway(
        classpath=os.pathsep.join(_package_classpath()),
        # Make the child JVM exit with this script instead of lingering.
        die_on_exit=True,
    )
    jString = gateway.jvm.__getattr__('java.lang.String')
    jDouble = gateway.jvm.__getattr__('java.lang.Double')
    jFile = gateway.jvm.__getattr__('java.io.File')
    jLinkedHashMap = gateway.jvm.__getattr__('java.util.LinkedHashMap')
    jLoadingModelEvaluatorBuilder = gateway.jvm.__getattr__('org.jpmml.evaluator.LoadingModelEvaluatorBuilder')
    jModelEvaluationContext = gateway.jvm.__getattr__('org.jpmml.evaluator.ModelEvaluationContext')
    jEvaluatorUtil = gateway.jvm.__getattr__('org.jpmml.evaluator.EvaluatorUtil')
elif lib == 'jnius':
    # The classpath must be configured before `jnius` itself is imported.
    import jnius_config
    jnius_config.set_classpath(*_package_classpath())
    import jnius

    jString = jnius.autoclass('java.lang.String')
    jDouble = jnius.autoclass('java.lang.Double')
    jFile = jnius.autoclass('java.io.File')
    jLinkedHashMap = jnius.autoclass('java.util.LinkedHashMap')
    jLoadingModelEvaluatorBuilder = jnius.autoclass('org.jpmml.evaluator.LoadingModelEvaluatorBuilder')
    jModelEvaluationContext = jnius.autoclass('org.jpmml.evaluator.ModelEvaluationContext')
    jEvaluatorUtil = jnius.autoclass('org.jpmml.evaluator.EvaluatorUtil')
else:  # lib == 'jpype' (validated above)
    import jpype
    import jpype.imports  # enables the `from java... import ...` statements below

    jpype.startJVM(classpath=_package_classpath())
    from java.lang import String as jString
    from java.lang import Double as jDouble
    from java.io import File as jFile
    from java.util import LinkedHashMap as jLinkedHashMap
    from org.jpmml.evaluator import LoadingModelEvaluatorBuilder as jLoadingModelEvaluatorBuilder
    from org.jpmml.evaluator import ModelEvaluationContext as jModelEvaluationContext
    from org.jpmml.evaluator import EvaluatorUtil as jEvaluatorUtil
elapsed = perf_counter() - start_t
print(f"createjvm: {elapsed:.3f}s")
##### Load a Model
# Build and verify an evaluator for the Iris decision-tree PMML 100 times and
# report total / worst-case / mean latency. The last `evaluator` built is reused
# by the sections that follow.
from time import perf_counter  # monotonic, high-resolution clock for benchmarking

times = []
for _ in range(100):
    start_t = perf_counter()
    evaluatorBuilder = jLoadingModelEvaluatorBuilder()
    evaluatorBuilder.setLocatable(True)
    evaluatorBuilder.load(jFile(jString("jpmml_evaluator/tests/resources/DecisionTreeIris.pmml")))
    evaluator = evaluatorBuilder.build().verify()
    times.append(perf_counter() - start_t)
total = sum(times)
print(f"loadmodel: tot={total:.6f} max={max(times):.6f}s avg={total / len(times):.6f}s")
##### Query field info
# Fetch the input, target and output field lists from the evaluator and read
# every field's name, 100 times over. The names are only collected to force the
# per-field `getName()` calls; print them when debugging.
from time import perf_counter  # monotonic, high-resolution clock for benchmarking

times = []
for _ in range(100):
    start_t = perf_counter()
    input_names = [field.getName() for field in evaluator.getInputFields()]
    target_names = [field.getName() for field in evaluator.getTargetFields()]
    output_names = [field.getName() for field in evaluator.getOutputFields()]
    times.append(perf_counter() - start_t)
total = sum(times)
print(f"fields : tot={total:.6f} max={max(times):.6f}s avg={total / len(times):.6f}s")
##### Score
# Evaluate the model on 100 slightly different Iris rows. The timed region covers
# the whole round trip: building the Python dict, copying it into a Java
# LinkedHashMap, encoding the keys, evaluating, and decoding the results.
from time import perf_counter  # monotonic, high-resolution clock for benchmarking

times = []
for i in range(100):
    val = i * 0.001  # small per-iteration offset so every row differs
    start_t = perf_counter()
    arguments1 = {
        "Sepal.Length": 5.1 + val,
        "Sepal.Width": 3.5 + val,
        "Petal.Length": 1.4 + val,
        "Petal.Width": 0.2 + val,
    }
    arguments2 = jLinkedHashMap()
    for k, v in arguments1.items():
        arguments2.put(jString(k), jDouble(v))
    arguments = jEvaluatorUtil.encodeKeys(arguments2)
    results1 = evaluator.evaluate(arguments)
    results2 = jEvaluatorUtil.decodeAll(results1)
    times.append(perf_counter() - start_t)
total = sum(times)
print(f"score : tot={total:.6f} max={max(times):.6f}s avg={total / len(times):.6f}s")
# jpype | |
# createjvm: 0.550s | |
# loadmodel: tot=1.466451 max=1.064521s avg=0.014665s | |
# fields : tot=0.019881 max=0.009795s avg=0.000199s | |
# score : tot=0.033356 max=0.023338s avg=0.000334s | |
# jnius | |
# createjvm: 0.249s | |
# loadmodel: tot=1.773011 max=1.385274s avg=0.017730s | |
# fields : tot=0.039058 max=0.012234s avg=0.000391s | |
# score : tot=0.067590 max=0.031904s avg=0.000676s | |
# py4j | |
# createjvm: 0.222s | |
# loadmodel: tot=0.616913 max=0.027464s avg=0.006169s | |
# fields : tot=0.699152 max=0.026426s avg=0.006992s | |
# score : tot=0.389583 max=0.017620s avg=0.003896s |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment