Skip to content

Instantly share code, notes, and snippets.

@aminnj
Created April 19, 2017 19:56
Show Gist options
  • Save aminnj/6ba2545ee82654e8dd479b6e7f2a6669 to your computer and use it in GitHub Desktop.
Save aminnj/6ba2545ee82654e8dd479b6e7f2a6669 to your computer and use it in GitHub Desktop.
XGB model to C++ function
###########################################
## Train the BDT and dump it to model.json ##
###########################################
import numpy as np
from sklearn.model_selection import train_test_split
import sys
import os
import xgboost as xgb
### Define test and train matrices somehow
# NOTE: X_train, y_train, X_test and y_test are not defined in this snippet;
# they have to be created before this point.

### Train BDT
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
# When several evaluation sets are given, xgboost uses the LAST one for
# early stopping, so the held-out set goes last (the original order made
# early stopping watch the training AUC).
evallist = [(dtrain, 'train'), (dtest, 'eval')]
num_round = 800
param = {}
param['max_depth'] = 5
param['eval_metric'] = "auc"
# The generated C++ function turns the summed leaf values into a score with
# a logistic sigmoid, which only matches a model trained with the logistic
# objective (xgboost's default objective is a regression one).
param['objective'] = "binary:logistic"
# Pass the dict itself: `param.items()` is a view object on Python 3 and is
# not accepted as a parameter container by recent xgboost versions.
bst = xgb.train(param, dtrain, num_round, evallist, early_stopping_rounds=20)
y_pred = bst.predict(dtest)

# Dump BDT into model file: one JSON object per tree, wrapped in a JSON list
with open("model.json", "w") as fhout:
    fhout.write("[\n" + ",\n".join(bst.get_dump(dump_format="json")) + "\n]")
###############################################
## Parse model.json (run as a separate script) ##
###############################################
import json
import pprint
import numpy as np
# Number of float arguments the generated C++ function will take.
# For now this has to be hardcoded.
nfeatures = 17

# Read the JSON model file into `js`; the code below iterates over it,
# treating each element as one tree.
with open("model.json", "r") as fhin:
    js = json.load(fhin)
def get_leaf(d_features, entry, depth=0):
    """Render one node of a dumped tree as a C++ expression.

    A leaf node (a dict with a "leaf" key) yields its stored value as-is.
    Any other node yields a parenthesised C++ ternary string of the form
    ``(<split> < <split_condition> ? <yes-subtree> : <no-subtree>)``, where
    both subtrees are rendered recursively.

    ``d_features`` is accepted but not used by this function; it is only
    forwarded to the recursive calls. ``depth`` is the current recursion
    depth and is likewise only forwarded (incremented by one per level).
    """
    if "leaf" in entry:
        return entry["leaf"]

    def first_child_for(branch):
        # First child whose node id equals the id stored under `branch`
        # ("yes" or "no"); raises IndexError when no child matches.
        return [kid for kid in entry["children"]
                if kid["nodeid"] == entry[branch]][0]

    feature = entry["split"]
    threshold = entry["split_condition"]
    on_true = first_child_for("yes")
    on_false = first_child_for("no")

    below = depth + 1
    true_expr = get_leaf(d_features, on_true, depth=below)
    false_expr = get_leaf(d_features, on_false, depth=below)
    return "({} < {} ? {} : {})".format(feature, threshold, true_expr, false_expr)
# Print c++ function to stdout, so will need to redirect
# like: `python blah.py > blah.h`
# The emitted code calls exp() but no #include line is printed, so the file
# that includes the generated header has to provide <cmath> itself.
colnames = ["float f" + str(ic) for ic in range(nfeatures)]
# Single-argument print(...) produces the same output on Python 2 and 3.
print("float get_prediction({}) {{".format(",".join(colnames)))
print(" float w = 0.;")
for j in js:
    # Each element of `js` is one tree; its rendered expression is added to w.
    # get_leaf() never reads its first argument, so None is passed (the
    # original passed an undefined name, `vals`, which raised NameError).
    print(" w += {};".format(get_leaf(None, j)))
# Map the summed tree outputs to (0, 1) with the logistic sigmoid.
print(" return 1.0/(1.0+exp(-1.0*w));")
print("}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment