-
-
Save hqucms/56844f4d1e04757704f6afcdaa6f65a8 to your computer and use it in GitHub Desktop.
import re
try:
    import xml.etree.cElementTree as ET  # C-accelerated module on old Pythons
except ImportError:
    # cElementTree was removed in Python 3.9; ElementTree has used the C
    # accelerator automatically since Python 3.3.
    import xml.etree.ElementTree as ET
regex_float_pattern = r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?'

# Line formats of the xgboost text dump, compiled once at import time.
# Leaf line:  <tabs><id>:leaf=<value>  optionally followed by ",<stats>..."
_LEAF_RE = re.compile(
    (r'(?P<indent>\t*)(?P<node>\d+):leaf=(?P<res>{0})(?:,.*)?$').format(regex_float_pattern))
# Split line: <tabs><id>:[feature<cut] yes=<id>,no=<id>  optionally followed by ",..."
_SPLIT_RE = re.compile(
    (r'(?P<indent>\t*)(?P<node>\d+):\[(?P<var>.+)<(?P<cut>{0})\]\s'
     r'yes=(?P<yes>\d+),no=(?P<no>\d+)').format(regex_float_pattern))


def _var_index(var, var_indices):
    """Return the TMVA variable index for the feature name ``var`` from a dump.

    A name present in ``var_indices`` is used directly. Otherwise, a default
    xgboost name ``f<i>`` (xgboost uses f0, f1, ... when the training data had
    no ``feature_names``) is mapped to variable ``i``. This assumes the
    variables were declared in training-column order.

    :raises KeyError: if the name can be resolved neither way.
    """
    if var in var_indices:
        return var_indices[var]
    default_name = re.match(r'f(\d+)$', var)
    if default_name is not None and int(default_name.group(1)) < len(var_indices):
        return int(default_name.group(1))
    raise KeyError(
        'feature {0!r} from the xgboost dump is not one of the input variables {1}; '
        'pass the feature names used in training (xgboost defaults to f0, f1, ...)'.format(
            var, sorted(var_indices)))


def build_tree(xgtree, base_xml_element, var_indices):
    """Append TMVA ``<Node>`` elements for one xgboost tree to ``base_xml_element``.

    ``xgtree`` is the text dump of a single tree, one node per line, e.g.::

        0:[var_topcand_mass<138.19] yes=1,no=2,missing=1
            1:leaf=0.25
            2:leaf=-0.5

    The indentation consists of tab characters, one per depth level. Split
    lines become ``nType="0"`` nodes that cut variable ``IVar`` at ``Cut``.
    Leaf lines become ``nType="-99"`` nodes whose ``res`` is the leaf value.
    Extra trailing fields on a leaf line (e.g. ``,cover=...``) are ignored.

    :param xgtree: one tree string, e.g. an element of ``Booster.get_dump()``.
    :param base_xml_element: element (``<BinaryTree>``) the root node is attached to.
    :param var_indices: mapping from feature name to TMVA variable index.
    :raises ValueError: if a line is neither a recognizable leaf nor a split.
    :raises KeyError: if a split's feature cannot be mapped to a variable index.
    """
    # Node id -> element its <Node> is attached under; the root (id '0')
    # hangs directly off the tree element.
    parent_element_dict = {'0': base_xml_element}
    # Node id -> TMVA position: 's' for the root, 'l'/'r' for the yes/no child.
    pos_dict = {'0': 's'}
    for line in xgtree.splitlines():
        # Strip only trailing whitespace/'\r': leading tabs encode the depth.
        line = line.rstrip()
        if not line:
            continue
        if ':leaf=' in line:
            result = _LEAF_RE.match(line)
            if result is None:
                raise ValueError('unrecognized xgboost leaf line: {0!r}'.format(line))
            inode = result.group('node')
            ET.SubElement(parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
                          depth=str(len(result.group('indent'))), NCoef="0", IVar="-1",
                          Cut="0.0e+00", cType="1", res=result.group('res'),
                          rms="0.0e+00", purity="0.0e+00", nType="-99")
        else:
            result = _SPLIT_RE.match(line)
            if result is None:
                raise ValueError('unrecognized xgboost split line: {0!r}'.format(line))
            inode = result.group('node')
            lnode = result.group('yes')
            rnode = result.group('no')
            # The 'yes' child (feature < cut) is stored as the left child.
            pos_dict[lnode] = 'l'
            pos_dict[rnode] = 'r'
            node_element = ET.SubElement(
                parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
                depth=str(len(result.group('indent'))), NCoef="0",
                IVar=str(_var_index(result.group('var'), var_indices)),
                Cut=result.group('cut'), cType="1", res="0.0e+00", rms="0.0e+00",
                purity="0.0e+00", nType="0")
            parent_element_dict[lnode] = node_element
            parent_element_dict[rnode] = node_element
def convert_model(model, input_variables, output_xml, pretty=True):
    """Write an xgboost model as a TMVA ``BDT::BDT`` XML weight file.

    :param model: list of per-tree text dumps, e.g. ``bst.get_dump()``.
    :param input_variables: sequence of ``(name, type)`` pairs such as
        ``[('var1', 'F'), ('var2', 'I')]``. A variable's position in this
        list becomes its TMVA ``VarIndex``. Names must match the feature
        names in the dump; xgboost uses ``f0``, ``f1``, ... when the
        training data had no ``feature_names``.
    :param output_xml: file name or binary file object passed to
        ``ElementTree.write``.
    :param pretty: if true and ``ElementTree.indent`` exists (Python 3.9+),
        indent the XML instead of writing it on a single line. Without
        ``indent`` the output is written unformatted.
    """
    var_indices = {}

    # <MethodSetup>
    method_setup = ET.Element("MethodSetup", Method="BDT::BDT")

    # <Variables>
    variables = ET.SubElement(method_setup, "Variables", NVar=str(len(input_variables)))
    for ind, val in enumerate(input_variables):
        # Index rather than unpack, so entries with extra fields still work.
        name, var_type = val[0], val[1]
        var_indices[name] = ind
        ET.SubElement(variables, "Variable", VarIndex=str(ind), Type=var_type,
                      Expression=name, Label=name, Title=name, Unit="", Internal=name,
                      Min="0.0e+00", Max="0.0e+00")

    # <GeneralInfo>
    general_info = ET.SubElement(method_setup, "GeneralInfo")
    ET.SubElement(general_info, "Info", name="Creator", value="xgboost2TMVA")
    ET.SubElement(general_info, "Info", name="AnalysisType", value="Classification")

    # <Options>
    options = ET.SubElement(method_setup, "Options")
    ET.SubElement(options, "Option", name="NodePurityLimit", modified="No").text = "5.00e-01"
    ET.SubElement(options, "Option", name="BoostType", modified="Yes").text = "Grad"

    # <Weights>: one <BinaryTree> (boostWeight 1) per dumped tree.
    weights = ET.SubElement(method_setup, "Weights", NTrees=str(len(model)), AnalysisType="1")
    for itree, xgtree in enumerate(model):
        binary_tree = ET.SubElement(weights, "BinaryTree", type="DecisionTree",
                                    boostWeight="1.0e+00", itree=str(itree))
        build_tree(xgtree, binary_tree, var_indices)

    tree = ET.ElementTree(method_setup)
    if pretty and hasattr(ET, "indent"):
        ET.indent(tree)
    tree.write(output_xml)
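# Optionally format the output with 'xmllint --format'.
# Example:
#   bst = xgb.train(param, d_train, num_round, watchlist)
#   model = bst.get_dump()
#   convert_model(model, input_variables=[('var1', 'F'), ('var2', 'I')], output_xml='xgboost.xml')
# The variable names must match the feature names in the dump; a model trained
# without feature_names uses 'f0', 'f1', ... as the names.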
@panwarlsweet
Did you specify the `feature_names` when defining `X_train`, e.g.,
`X_train = xgb.DMatrix(X, feature_names=train_vars)`?
Otherwise xgboost will use 'f0', 'f1', ..., as the names of the input variables.
@hqucms
Thanks a lot for giving the hint.
Actually, instead of putting the variable name "('leadingJet_DeepCSV','F')", I had to use "('f0','F')" and so on. I hadn't realized that.
Now it works fine. :)
Thanks again!
@hqucms Do you know of anyone who has written a script for the inverse of this process, i.e., TMVA XML to xgboost?
@lgray No, I am not aware of one.
@hqucms After talking a bit with @guitargeek this now exists! https://github.com/guitargeek/tmva-to-xgboost
@lgray Nice! Thank you for letting me know :)
Do you know why the output XML does not have the right indentation? I only see one line in the output.
Never mind, I think it would work with:
import xml.dom.minidom
tree = ET.ElementTree(MethodSetup)
tmp_path = f'/tmp/{hash}.xml'
tree.write(tmp_path)
with open(tmp_path) as ifile:
xml_data = ifile.read()
parser = xml.dom.minidom.parseString(xml_data)
xml_pretty = parser.toprettyxml()
with open(output_xml, 'w') as ofile:
ofile.write(xml_pretty)
Do you know why the output XML does not have the right indentation? I only see one line in the output.
One can also use `xmllint --format`.
Do you know why the output XML does not have the right indentation? I only see one line in the output.
One can also use `xmllint --format`.
Yeah, but it's better to do the formatting inside the Python code itself: it's easier if you want a well-formatted XML file, and the other approach is not of any use.
Hello Huilin,
I am using XGBoost for the first time, and I want to convert an XGB model into an XML file. For this, I am using your Python code, but when I do this I get an error that I don't understand. The error is:
KeyErrorTraceback (most recent call last)
in ()
1 model = xgb1.get_booster().get_dump()
2 print(len(model))
----> 3 convert_model(model,input_variables=[('leadingJet_DeepCSV','F'),('subleadingJet_DeepCSV','F'),('absCosThetaStar_CS','F'),('absCosTheta_bb','F'),('absCosTheta_gg','F')],output_xml='xgboost.xml')
in convert_model(model, input_variables, output_xml)
70 for itree in range(NTrees):
71 BinaryTree = ET.SubElement(Weights, "BinaryTree", type="DecisionTree", boostWeight="1.0e+00", itree=str(itree))
---> 72 build_tree(model[itree], BinaryTree, var_indices)
73
74 tree = ET.ElementTree(MethodSetup)
in build_tree(xgtree, base_xml_element, var_indices)
32 pos_dict[rnode] = 'r'
33 node_elementTree = ET.SubElement(parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
---> 34 depth=str(depth), NCoef="0", IVar=str(var_indices[var]), Cut=str(cut),
35 cType="1", res="0.0e+00", rms="0.0e+00", purity="0.0e+00", nType="0")
36 parent_element_dict[lnode] = node_elementTree
KeyError: 'f0'
Could you please help if you have any idea?
I added the following lines to my code:
xgb1.fit(X_train, y_train)
model = xgb1.get_booster().get_dump()
print(len(model))
convert_model(model,input_variables=[('leadingJet_DeepCSV','F'),('subleadingJet_DeepCSV','F'),('absCosThetaStar_CS','F'),('absCosTheta_bb','F'),('absCosTheta_gg','F')],output_xml='xgboost.xml')
Thanks,
Lata