Skip to content

Instantly share code, notes, and snippets.

@matwey
Last active September 11, 2020 09:32
Show Gist options
  • Save matwey/82237d2d5b38213d3d7fce3ddb318e65 to your computer and use it in GitHub Desktop.
Save matwey/82237d2d5b38213d3d7fce3ddb318e65 to your computer and use it in GitHub Desktop.
dump/load IsolationForest to/from ONNX
#!/usr/bin/env python3
"""Train an IsolationForest on two synthetic 2-D clusters and dump it to ONNX.

Running this file writes the converted model to ``isof.onnx`` in the
current working directory.
"""
import numpy as np
from sklearn.ensemble import IsolationForest
# Need skl2onnx 1.7.1+
from skl2onnx import convert_sklearn  # noqa: F401 -- not used below; to_onnx is the entry point used
from skl2onnx import to_onnx


def make_training_data(rng, n_per_cluster=20, centers=((10.0, 10.0), (-10.0, -10.0))):
    """Return float32 samples drawn around each of *centers*, stacked in order.

    Each cluster has *n_per_cluster* rows drawn with ``rng.multivariate_normal``
    using an identity covariance (unit variance, uncorrelated axes).

    Args:
        rng: a NumPy ``Generator`` (e.g. from ``np.random.default_rng``).
        n_per_cluster: number of samples per cluster.
        centers: sequence of cluster means; all must have the same length.

    Returns:
        An array of shape ``(len(centers) * n_per_cluster, len(center))``
        with dtype float32.
    """
    clusters = [
        rng.multivariate_normal(np.asarray(center, dtype=float), np.eye(len(center)), n_per_cluster)
        for center in centers
    ]
    # to_onnx infers the model's input type from this array, so cast to
    # float32 here to get a float32 input in the exported graph.
    return np.vstack(clusters).astype(np.float32)


def export_isolation_forest(path="isof.onnx", seed=None):
    """Fit an IsolationForest on synthetic data and write it to *path* as ONNX.

    Args:
        path: destination file for the serialized ONNX model.
        seed: seed for the NumPy generator used to draw the training data;
            ``None`` draws fresh entropy on every run. The forest itself is
            always built with ``random_state=0``.

    Returns:
        The converted ONNX model object produced by ``to_onnx``.
    """
    rng = np.random.default_rng(seed)
    X = make_training_data(rng)
    clf = IsolationForest(random_state=0).fit(X)
    model_onnx = to_onnx(clf, X)
    # The write must sit inside the ``with`` so the file is open (and then
    # closed) around it.
    with open(path, "wb") as f:
        f.write(model_onnx.SerializeToString())
    return model_onnx


if __name__ == "__main__":
    export_isolation_forest()
#!/usr/bin/env python
"""Load the IsolationForest saved as ONNX and score a few 2-D points with onnxruntime."""
import onnxruntime as rt
import numpy as np


def score_with_onnx(model_path="isof.onnx", points=None, output_name="scores"):
    """Run the ONNX model at *model_path* on *points* and print the requested output.

    Args:
        model_path: path to the serialized ONNX model.
        points: 2-D array-like of query points; defaults to four probe points
            (origin, both cluster centers, and one point between them on x).
        output_name: name of the graph output to fetch. The default
            ``'scores'`` is what the original script requested; presumably
            the skl2onnx IsolationForest converter also emits a ``'label'``
            output -- check the printed output list.

    Returns:
        The list returned by ``InferenceSession.run`` (one entry per
        requested output).
    """
    if points is None:
        points = [[0.0, 0.0], [10.0, 10.0], [-10.0, -10.0], [10.0, 0.0]]
    # The exporter was fed float32 data, so cast the queries to float32 to
    # match the graph's declared input type.
    X_test = np.asarray(points, dtype=np.float32)

    # Explicit providers: onnxruntime >= 1.9 builds that include non-CPU
    # execution providers raise ValueError when this argument is omitted.
    sess = rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    print("Inputs: {}".format([x.name for x in sess.get_inputs()]))
    print("Outputs: {}".format([x.name for x in sess.get_outputs()]))

    input_name = sess.get_inputs()[0].name
    pred_onx = sess.run([output_name], {input_name: X_test})
    print(pred_onx)
    return pred_onx


if __name__ == "__main__":
    score_with_onnx()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment