Created
April 19, 2017 19:56
-
-
Save aminnj/6ba2545ee82654e8dd479b6e7f2a6669 to your computer and use it in GitHub Desktop.
XGB model to C++ function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
########################################
## # Train BDT and output model .json ##
########################################
import os
import sys

import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split

### Define test and train matrices somehow
# NOTE: X_train, y_train, X_test and y_test are not defined in this snippet;
# they have to be created before the lines below can run.

### Train BDT
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# Keep the validation set LAST: when several evaluation sets are given,
# xgboost uses the last one to decide when to stop early, and stopping on
# the training set itself would defeat the purpose.
evallist = [(dtrain, "train"), (dtest, "eval")]
num_round = 800
param = {
    "max_depth": 5,
    "eval_metric": "auc",
}

# Pass the parameter dict itself and use keywords for the optional
# arguments, so the call does not depend on positional argument order.
bst = xgb.train(
    param,
    dtrain,
    num_boost_round=num_round,
    evals=evallist,
    early_stopping_rounds=20,
)
y_pred = bst.predict(dtest)

# Dump BDT into model file: the per-tree JSON strings are joined into a
# single JSON array so the file can be read back with one json.load call.
# NOTE(review): presumably this dumps every boosted tree, including rounds
# trained after the best early-stopping iteration -- confirm.
with open("model.json", "w") as fhout:
    fhout.write("[\n" + ",\n".join(bst.get_dump(dump_format="json")) + "\n]")
########################################
#### # Parse .json in separate file ####
########################################
import json
import pprint

import numpy as np

nfeatures = 17  # for now, have to hardcode number of variables

# Load the model dump; the file is expected to hold a single JSON array
# (the training script writes one array element per dumped tree).
with open("model.json", "r") as fhin:
    js = json.load(fhin)
def get_leaf(d_features, entry, depth=0):
    """Translate one dumped tree node into a C++ expression.

    Recursively ventures into the tree and collects leaf values, making a
    nested set of C++ ternary statements (one per split node).

    Args:
        d_features: Not read by this function; kept so existing call sites
            keep working. It is forwarded unchanged to recursive calls.
        entry: Node dictionary. A leaf node carries a "leaf" key. A split
            node carries "split", "split_condition", "yes", "no" and a
            "children" list whose items each have a "nodeid".
        depth: Current recursion depth. It is incremented on each recursive
            call but does not influence the generated expression.

    Returns:
        The raw value stored under "leaf" for a leaf node; otherwise the
        string "(<split> < <split_condition> ? <yes-branch> : <no-branch>)".

    Raises:
        KeyError: If a split node lacks one of the expected keys, or if its
            "yes"/"no" node id does not match any child.
    """
    if "leaf" in entry:
        return entry["leaf"]

    splitvar = entry["split"]
    splitval = entry["split_condition"]

    # Index the children by node id once instead of scanning the list
    # separately for each branch.
    children = {child["nodeid"]: child for child in entry["children"]}
    yesnode = children[entry["yes"]]
    nonode = children[entry["no"]]

    # NOTE(review): only the "yes"/"no" branches are used; if the dump also
    # records a branch for missing values it is ignored here -- confirm
    # that inputs never contain missing values.
    return "({} < {} ? {} : {})".format(
        splitvar,
        splitval,
        get_leaf(d_features, yesnode, depth=depth + 1),
        get_leaf(d_features, nonode, depth=depth + 1),
    )
# Print c++ function to stdout, so will need to redirect
# like: `python blah.py > blah.h`
# The emitted function calls exp(), so the C++ code that includes the
# generated header must make a math header available.
# NOTE(review): the argument names f0..f<nfeatures-1> are assumed to match
# the "split" names stored in the model dump -- confirm.
colnames = ["float f" + str(ic) for ic in range(nfeatures)]

# Single-argument print(...) behaves the same under Python 2 and Python 3.
print("float get_prediction({}) {{".format(",".join(colnames)))
print(" float w = 0.;")
for j in js:
    # get_leaf does not read its first argument, so None is passed as a
    # placeholder.
    print(" w += {};".format(get_leaf(None, j)))
print(" return 1.0/(1.0+exp(-1.0*w));")
print("}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment