Skip to content

Instantly share code, notes, and snippets.

@fanwei918
Forked from raddy/xgboost_to_c.py
Created November 23, 2017 18:04
Show Gist options
  • Save fanwei918/cdd6e75aecbefaa38ec8c2cb3c1fca30 to your computer and use it in GitHub Desktop.
Save fanwei918/cdd6e75aecbefaa38ec8c2cb3c1fca30 to your computer and use it in GitHub Desktop.
Convert XGBoost model to C header
import re
_comparer = None
import contextlib
#ALWAYS_INLINE = "__attribute__((__always_inline__))"
ALWAYS_INLINE = "ALWAYS_INLINE"
class CodeGenerator(object):
    """Accumulate lines of generated source text with a running indent level.

    Each call to ``write`` stores one line, prefixed with one space per
    current indent level. ``bracketed`` raises the indent level for the
    duration of a ``with`` block.
    """

    def __init__(self):
        self._lines = []   # emitted lines, in order
        self._indent = 0   # current nesting depth

    @property
    def lines(self):
        """Return the list of lines written so far (the live list, not a copy)."""
        return self._lines

    def write(self, line):
        """Append *line*, prefixed with one space per indent level."""
        self._lines.append(" " * self._indent + line)

    @contextlib.contextmanager
    def bracketed(self, preamble, postamble):
        """Write *preamble*, indent the ``with`` body, then write *postamble*.

        The indent level is restored even if the body raises, so the
        generator remains usable afterwards. In that case *postamble* is not
        written and the exception propagates.
        """
        assert self._indent >= 0
        self.write(preamble)
        self._indent += 1
        try:
            yield
        finally:
            # Always undo the indent; otherwise an exception inside the body
            # would leave every later line over-indented.
            self._indent -= 1
        self.write(postamble)
class XgbModel:
    """An ensemble of parsed XGBoost trees that can be rendered as C source.

    Trees are kept in insertion order in ``XgbTrees``. Every tree object is
    expected to expose ``node``, ``left`` and ``right`` attributes, where
    ``node`` carries ``IsLeaf``, ``LeafValue``, ``Feature`` and ``SplitValue``.
    """

    def __init__(self, verbosity=0):
        self._verbosity = verbosity
        self.XgbTrees = []
        # The fields below are initialised here but not read by any method
        # of this class.
        self._treeIndex = 0
        self._maxDeepening = 0
        self._pathMemo = []
        self._maxInteractionDepth = 0

    def AddTree(self, tree):
        """Append *tree* to the ensemble."""
        self.XgbTrees.append(tree)

    def code_gen_ensemble(self, fn, gen=None):
        """Emit one C function per tree plus a summing function named *fn*.

        Per-tree functions are named ``<fn>_boost_<index>``. The final
        function adds up their results and returns the total.

        Args:
            fn: Name of the generated top-level C function.
            gen: Optional ``CodeGenerator`` to write into; a new one is
                created when omitted.

        Returns:
            The generator's list of emitted lines.
        """
        if gen is None:
            gen = CodeGenerator()
        prefix = fn + '_boost'
        tree_names = []
        for i, tree in enumerate(self.XgbTrees):
            name = "{name}_{index}".format(name=prefix, index=i)
            self.code_gen_tree(tree, name, gen)
            tree_names.append(name)
        fn_decl = "{inline} double {name}(double* f) {{".format(
            inline=ALWAYS_INLINE, name=fn)
        with gen.bracketed(fn_decl, "}"):
            gen.write("double result = 0.;")
            for name in tree_names:
                gen.write("result += {name}(f);".format(name=name))
            gen.write("return result;")
        return gen.lines

    def xgb_to_c(self, fn):
        """Return the whole ensemble as the text of a C++ header.

        The generated functions are wrapped in ``#pragma once`` and the
        ``atm::trader::models`` namespace.
        """
        lines = self.code_gen_ensemble(fn=fn)
        header = "#pragma once\n" + "namespace atm { namespace trader { namespace models {\n\n"
        footer = "\n\n }}}"
        return header + "\n".join(lines) + footer

    def code_gen_tree(self, tree, fn='boost', gen=None):
        """Emit a C function named *fn* that evaluates a single tree.

        Args:
            tree: Root tree object (see the class docstring for its shape).
            fn: Name of the generated C function.
            gen: Optional ``CodeGenerator`` to write into; a new one is
                created when omitted.

        Returns:
            The generator's list of emitted lines.
        """
        if gen is None:
            gen = CodeGenerator()

        def recur(subtree):
            node = subtree.node
            if node.IsLeaf:
                gen.write("return {0};".format(node.LeafValue))
                return
            # The dump format parsed elsewhere in this file is
            # "[feature<threshold] yes=..,no=..", with the yes-child stored
            # as the left subtree, so the comparison must be a strict "<".
            # The first character of the feature name is dropped to get the
            # array index -- presumably names look like "f12"; verify
            # against the dump being converted.
            # NOTE(review): the generated code has no explicit
            # missing-value branch; a NaN input falls into the else branch.
            branch = "if (f[{feature}] < {threshold}f) {{".format(
                feature=node.Feature[1:],
                threshold=node.SplitValue)
            with gen.bracketed(branch, "}"):
                recur(subtree.left)
            with gen.bracketed("else {", "}"):
                recur(subtree.right)

        fn_decl = "{inline} double {name}(double* f) {{".format(
            inline=ALWAYS_INLINE,
            name=fn)
        with gen.bracketed(fn_decl, "}"):
            recur(tree)
        return gen.lines
class XgbModelParser:
    """Parse a textual XGBoost model dump into an ``XgbModel``.

    Split lines are matched by ``nodeRegex`` and leaf lines by ``leafRegex``.
    Nodes of the tree currently being read are buffered in ``xgbNodeList``,
    keyed by node number.
    """

    def __init__(self, verbosity=0):
        self._verbosity = verbosity
        # Raw strings keep the regex escapes (\d, \[, \s) away from Python's
        # own string-escape processing.
        self.nodeRegex = re.compile(
            r"(\d+):\[(.*)<(.+)\]\syes=(.*),no=(.*),missing=.*,gain=(.*),cover=(.*)")
        self.leafRegex = re.compile(r"(\d+):leaf=(.*),cover=(.*)")

    def ConstructXgbTree(self, tree):
        """Recursively attach child subtrees to *tree* from ``xgbNodeList``."""
        if tree.node.LeftChild is not None:
            tree.left = XgbTree(self.xgbNodeList[tree.node.LeftChild])
            self.ConstructXgbTree(tree.left)
        if tree.node.RightChild is not None:
            tree.right = XgbTree(self.xgbNodeList[tree.node.RightChild])
            self.ConstructXgbTree(tree.right)

    def ParseXgbTreeNode(self, line):
        """Parse one stripped dump line into an ``XgbTreeNode``.

        Returns:
            The parsed node, or ``None`` when *line* matches neither the
            leaf nor the split pattern.
        """
        # Try the anchored leaf pattern first rather than testing for the
        # substring "leaf", which would also match a split line whose
        # feature name happens to contain "leaf".
        leaf = self.leafRegex.match(line)
        if leaf:
            node = XgbTreeNode()
            node.Number = int(leaf.group(1))
            node.LeafValue = float(leaf.group(2))
            node.Cover = float(leaf.group(3))
            node.IsLeaf = True
            return node

        split = self.nodeRegex.match(line)
        if split is None:
            return None
        node = XgbTreeNode()
        node.Number = int(split.group(1))
        node.Feature = split.group(2)
        node.SplitValue = float(split.group(3))
        node.LeftChild = int(split.group(4))    # the "yes" branch
        node.RightChild = int(split.group(5))   # the "no" branch
        node.Gain = float(split.group(6))
        node.Cover = float(split.group(7))
        node.IsLeaf = False
        return node

    def _finish_tree(self, model, numTree):
        """Build a tree from the buffered nodes, add it to *model*, clear the buffer."""
        if self._verbosity >= 2:
            print("Constructing tree #{}".format(numTree))
        tree = XgbTree(self.xgbNodeList[0])
        self.ConstructXgbTree(tree)
        model.AddTree(tree)
        self.xgbNodeList = {}

    def GetXgbModelFromFile(self, fileName, maxTrees):
        """Read a model dump from *fileName* and return it as an ``XgbModel``.

        Blank lines and lines starting with ``booster`` end the current tree.

        Args:
            fileName: Path of the text dump to read.
            maxTrees: Stop after this many trees; a negative value never
                triggers the limit.

        Returns:
            The populated ``XgbModel``, or ``None`` if a line cannot be
            parsed as a node.
        """
        model = XgbModel(self._verbosity)
        self.xgbNodeList = {}
        numTree = 0
        with open(fileName) as f:
            for line in f:
                line = line.strip()
                if (not line) or line.startswith('booster'):
                    # Check the buffer itself, not any() over its keys: a
                    # single-leaf tree has only key 0, which is falsy.
                    if self.xgbNodeList:
                        numTree += 1
                        self._finish_tree(model, numTree)
                        if numTree == maxTrees:
                            if self._verbosity >= 1:
                                print('maxTrees reached')
                            break
                else:
                    node = self.ParseXgbTreeNode(line)
                    if not node:
                        return None
                    self.xgbNodeList[node.Number] = node
        # The last tree in the file has no separator line after it.
        if self.xgbNodeList and ((maxTrees < 0) or (numTree < maxTrees)):
            numTree += 1
            self._finish_tree(model, numTree)
        return model
class XgbTreeNode:
    """One node of a parsed tree: either a split or a leaf.

    All fields start at neutral defaults and are filled in by whoever
    creates the node. Nodes order by their ``Number``.
    """

    def __init__(self):
        self.Feature = ''        # feature name of a split node
        self.Gain = 0.0
        self.Cover = 0.0
        self.Number = -1         # node id; -1 until assigned
        self.LeftChild = None    # child node number, or None
        self.RightChild = None   # child node number, or None
        self.LeafValue = 0.0
        self.SplitValue = 0.0
        self.IsLeaf = False

    def __lt__(self, other):
        """Order nodes by ``Number``.

        Returns ``NotImplemented`` when *other* has no ``Number`` attribute,
        so Python reports the usual ``TypeError`` for unsupported operands
        rather than an ``AttributeError`` raised from inside this method.
        """
        try:
            other_number = other.Number
        except AttributeError:
            return NotImplemented
        return self.Number < other_number
class XgbTree:
    """A binary tree cell: one node plus optional left and right subtrees."""

    def __init__(self, node):
        # The node is kept by reference; use node.copy() here instead if an
        # independent copy is ever needed.
        self.node = node
        # Both children start out unset.
        self.left = self.right = None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment