-
-
Save fanwei918/cdd6e75aecbefaa38ec8c2cb3c1fca30 to your computer and use it in GitHub Desktop.
Convert XGBoost model to C header
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib
import re
import sys

_comparer = None
# Token written in front of every generated C++ function declaration.
# The generated header does not #define this macro itself, so the code that
# includes the header must provide it. (A GCC-specific alternative is to emit
# "__attribute__((__always_inline__))" here directly.)
ALWAYS_INLINE = "ALWAYS_INLINE"
class CodeGenerator(object):
    """Accumulate indented lines of generated C++ source text."""

    def __init__(self):
        self._lines = []  # emitted lines, indentation already applied
        self._indent = 0  # current nesting depth

    @property
    def lines(self):
        """The lines emitted so far (the live internal list, not a copy)."""
        return self._lines

    def write(self, line):
        """Append *line*, prefixed with one space per nesting level."""
        self._lines.append(" " * self._indent + line)

    @contextlib.contextmanager
    def bracketed(self, preamble, postamble):
        """Write *preamble*, indent the body of the ``with`` block, then *postamble*.

        Args:
            preamble: line written before the nested body (e.g. ``"if (...) {"``).
            postamble: line written after the nested body (e.g. ``"}"``).

        If the body raises, the nesting depth is still restored so the
        generator remains usable; the postamble is only written on success.
        """
        assert self._indent >= 0
        self.write(preamble)
        self._indent += 1
        try:
            yield
        finally:
            # Without this, an exception in the body would leave every later
            # line of a reused generator over-indented.
            self._indent -= 1
        self.write(postamble)
class XgbModel:
    """An ensemble of parsed XGBoost trees that can be rendered as C++ source.

    Trees are appended with AddTree() (XgbModelParser does this while reading
    a text dump) and rendered with xgb_to_c() or code_gen_ensemble().
    """

    def __init__(self, verbosity=0):
        self._verbosity = verbosity
        self.XgbTrees = []  # tree roots (XgbTree), in the order they were added
        # NOTE(review): the attributes below are not read anywhere in this
        # file; kept only in case external code relies on them.
        self._treeIndex = 0
        self._maxDeepening = 0
        self._pathMemo = []
        self._maxInteractionDepth = 0

    def AddTree(self, tree):
        """Append *tree* (an XgbTree root) to the ensemble."""
        self.XgbTrees.append(tree)

    def code_gen_ensemble(self, fn, gen=None):
        """Emit one function per tree plus an entry point ``fn`` that sums them.

        Args:
            fn: name of the generated entry-point function. Per-tree functions
                are named ``{fn}_boost_{i}``.
            gen: optional CodeGenerator to append to; a new one is created
                when omitted.

        Returns:
            The generator's list of emitted lines.

        Only the per-tree leaf values are summed (the accumulator starts at
        0), so no global bias / base score is added here.
        """
        if gen is None:
            gen = CodeGenerator()
        prefix = fn + '_boost'
        for i, tree in enumerate(self.XgbTrees):
            self.code_gen_tree(tree, "{name}_{index}".format(name=prefix, index=i), gen)

        fn_decl = "{inline} double {name}(double* f) {{".format(inline=ALWAYS_INLINE, name=fn)
        with gen.bracketed(fn_decl, "}"):
            gen.write("double result = 0.;")
            for i in range(len(self.XgbTrees)):
                gen.write("result += {name}_{index}(f);".format(name=prefix, index=i))
            gen.write("return result;")
        return gen.lines

    def xgb_to_c(self, fn):
        """Return a complete C++ header defining the ensemble function ``fn``.

        The code is wrapped in ``namespace atm::trader::models``. It uses the
        ALWAYS_INLINE macro, which this header does not define; the including
        translation unit must provide it.
        """
        lines = self.code_gen_ensemble(fn=fn)
        header = "#pragma once\n" + "namespace atm { namespace trader { namespace models {\n\n"
        footer = "\n\n }}}"
        assert lines is not None
        code = "\n".join(lines)
        return header + code + footer

    def code_gen_tree(self, tree, fn='boost', gen=None):
        """Emit ``double fn(double* f)`` evaluating a single tree.

        Args:
            tree: XgbTree root to render.
            fn: name of the generated function.
            gen: optional CodeGenerator to append to; a new one is created
                when omitted.

        Returns:
            The generator's list of emitted lines.
        """
        if gen is None:
            gen = CodeGenerator()

        def recur(subtree):
            node = subtree.node
            if node.IsLeaf:
                gen.write("return {0};".format(node.LeafValue))
                return
            # The dump encodes splits as "[f<i><threshold] yes=..,no=..": the
            # "yes" (left) child is taken only for strictly smaller values, so
            # the comparison must be '<'. '<=' sent ties to the wrong branch.
            # Feature[1:] assumes default XGBoost names like "f12" -- TODO
            # confirm no custom feature map is used.
            # NOTE(review): the dump's missing= direction is not parsed, so a
            # NaN feature always takes the else branch (NaN < x is false).
            branch = "if (f[{feature}] < {threshold}f) {{".format(
                feature=node.Feature[1:],
                threshold=node.SplitValue)
            with gen.bracketed(branch, "}"):
                recur(subtree.left)
            with gen.bracketed("else {", "}"):
                recur(subtree.right)

        fn_decl = "{inline} double {name}(double* f) {{".format(
            inline=ALWAYS_INLINE,
            name=fn)
        with gen.bracketed(fn_decl, "}"):
            recur(tree)
        return gen.lines
class XgbModelParser:
    """Parse an XGBoost text dump (dumped with stats) into an XgbModel.

    Recognised lines, after stripping whitespace:
        "<id>:[<feature><<threshold>] yes=<id>,no=<id>,missing=<id>,gain=<g>,cover=<c>"
        "<id>:leaf=<value>,cover=<c>"
    A blank line or a line starting with "booster" ends the current tree.
    """

    def __init__(self, verbosity=0):
        self._verbosity = verbosity
        # Raw strings: the pattern text is unchanged, but "\d", "\[", "\s"
        # no longer trigger invalid-escape warnings.
        self.nodeRegex = re.compile(r"(\d+):\[(.*)<(.+)\]\syes=(.*),no=(.*),missing=.*,gain=(.*),cover=(.*)")
        self.leafRegex = re.compile(r"(\d+):leaf=(.*),cover=(.*)")
        # Nodes of the tree currently being read, keyed by node number.
        self.xgbNodeList = {}

    def ConstructXgbTree(self, tree):
        """Recursively attach child subtrees to *tree* using self.xgbNodeList."""
        node = tree.node
        if node.LeftChild is not None:
            tree.left = XgbTree(self.xgbNodeList[node.LeftChild])
            self.ConstructXgbTree(tree.left)
        if node.RightChild is not None:
            tree.right = XgbTree(self.xgbNodeList[node.RightChild])
            self.ConstructXgbTree(tree.right)

    def ParseXgbTreeNode(self, line):
        """Parse one stripped dump line into an XgbTreeNode.

        Raises:
            ValueError: if *line* matches neither the leaf nor the split format.
        """
        # Match the leaf pattern first instead of testing '"leaf" in line',
        # which misfires on feature names that contain "leaf".
        m = self.leafRegex.match(line)
        if m is not None:
            node = XgbTreeNode()
            node.Number = int(m.group(1))
            node.LeafValue = float(m.group(2))
            node.Cover = float(m.group(3))
            node.IsLeaf = True
            return node

        m = self.nodeRegex.match(line)
        if m is None:
            raise ValueError("Unrecognised XGBoost dump line: {!r}".format(line))
        node = XgbTreeNode()
        node.Number = int(m.group(1))
        node.Feature = m.group(2)
        node.SplitValue = float(m.group(3))
        node.LeftChild = int(m.group(4))   # "yes" branch
        node.RightChild = int(m.group(5))  # "no" branch
        node.Gain = float(m.group(6))
        node.Cover = float(m.group(7))
        node.IsLeaf = False
        return node

    def _FinishTree(self, model, numTree):
        """Build the pending tree from self.xgbNodeList, add it to *model*, reset.

        Returns the updated tree count.
        """
        numTree += 1
        if self._verbosity >= 2:
            sys.stdout.write("Constructing tree #{}\n".format(numTree))
        tree = XgbTree(self.xgbNodeList[0])  # node 0 is the root
        self.ConstructXgbTree(tree)
        model.AddTree(tree)
        self.xgbNodeList = {}
        return numTree

    def GetXgbModelFromFile(self, fileName, maxTrees=-1):
        """Read an XGBoost text dump and return the parsed XgbModel.

        Args:
            fileName: path of the text dump.
            maxTrees: maximum number of trees to load; negative means no limit.

        Returns:
            An XgbModel holding at most maxTrees trees (all trees if negative).

        Raises:
            ValueError: from ParseXgbTreeNode for an unrecognised line.
        """
        model = XgbModel(self._verbosity)
        self.xgbNodeList = {}
        numTree = 0

        def limit_reached():
            return 0 <= maxTrees <= numTree

        with open(fileName) as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('booster'):
                    node = self.ParseXgbTreeNode(line)
                    self.xgbNodeList[node.Number] = node
                    continue
                # Separator: finish the pending tree. Test the dict itself,
                # not any(dict) -- any() looks at the keys, and a single-leaf
                # tree has only key 0, which is falsy, so it was dropped.
                if self.xgbNodeList and not limit_reached():
                    numTree = self._FinishTree(model, numTree)
                if limit_reached():
                    if self._verbosity >= 1:
                        print('maxTrees reached')
                    break
        # The last tree in the file has no trailing separator.
        if self.xgbNodeList and not limit_reached():
            numTree = self._FinishTree(model, numTree)
        self.xgbNodeList = {}
        return model
class XgbTreeNode:
    """One node of an XGBoost tree as read from a text dump.

    Split nodes fill Feature, SplitValue, LeftChild, RightChild and Gain;
    leaf nodes fill LeafValue. Both kinds carry Number and Cover.
    Nodes compare (``<``) by their Number.
    """

    def __init__(self):
        # Split statistics.
        self.Feature = ''
        self.Gain = 0.0
        self.Cover = 0.0
        # Node number (-1 until parsed) and child node numbers (None unless set).
        self.Number = -1
        self.LeftChild = None
        self.RightChild = None
        # Leaf output and split threshold.
        self.LeafValue = 0.0
        self.SplitValue = 0.0
        self.IsLeaf = False

    def __lt__(self, other):
        """Order nodes by node number."""
        return other.Number > self.Number
class XgbTree:
    """Binary tree wrapper around an XgbTreeNode.

    ``left`` holds the subtree for the node's LeftChild ("yes" branch) and
    ``right`` the subtree for its RightChild ("no" branch); both remain None
    until XgbModelParser.ConstructXgbTree fills them in (and stay None for
    leaves).
    """

    def __init__(self, node):
        self.left = None   # subtree for node.LeftChild, if any
        self.right = None  # subtree for node.RightChild, if any
        self.node = node   # stored by reference (not copied)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment