Last active
March 21, 2024 17:25
-
-
Save hqucms/56844f4d1e04757704f6afcdaa6f65a8 to your computer and use it in GitHub Desktop.
Convert xgboost model to TMVA xml format. Not fully tested...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re

try:
    # C-accelerated implementation on Python 2 and Python < 3.9.
    import xml.etree.cElementTree as ET
except ImportError:
    # xml.etree.cElementTree was removed in Python 3.9; the plain module
    # exposes the same API.
    import xml.etree.ElementTree as ET
regex_float_pattern = r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?'

# Compiled once at import time instead of on every dump line.
# Leaf line, e.g. "\t\t4:leaf=0.0123"; trailing ",cover=..." stats are tolerated.
_LEAF_RE = re.compile(
    r'(\t*)(\d+):leaf=(?P<res>{0})(?:,.*)?$'.format(regex_float_pattern))
# Split line, e.g. "\t\t3:[var_topcand_mass<138.19] yes=7,no=8,missing=7"
_SPLIT_RE = re.compile(
    r'(\t*)([0-9]+):\[(?P<var>.+)<(?P<cut>{0})\]\syes=(?P<yes>\d+),no=(?P<no>\d+)'.format(
        regex_float_pattern))


def build_tree(xgtree, base_xml_element, var_indices):
    """Append one tree from an xgboost text dump as nested ``<Node>`` elements.

    Each non-empty line of *xgtree* is either a split line such as
    ``"\\t\\t3:[var<138.19] yes=7,no=8,missing=7"`` or a leaf line such as
    ``"\\t\\t7:leaf=0.05"``. The number of leading tabs is written as the
    node's ``depth`` attribute. Every node is attached under the element of
    the split that named it in ``yes=``/``no=``; node ``0`` is attached
    directly under *base_xml_element*.

    Args:
        xgtree: text dump of a single tree (newline separated lines).
        base_xml_element: XML element that receives the root ``<Node>``.
        var_indices: mapping from variable name (as it appears in the dump)
            to the value written to the ``IVar`` attribute.

    Returns:
        None. *base_xml_element* is modified in place.

    Raises:
        ValueError: if a line matches neither the leaf nor the split format.
        KeyError: if a split uses a variable missing from *var_indices*, or
            a node id appears before the split line that references it.
    """
    # node id -> element the node has to be attached to
    parent_element_dict = {'0': base_xml_element}
    # node id -> value of the "pos" attribute: 's' for the root,
    # 'l' for a "yes" child and 'r' for a "no" child
    pos_dict = {'0': 's'}

    for line in xgtree.split('\n'):
        if not line:
            continue

        if ':leaf=' in line:
            result = _LEAF_RE.match(line)
            if not result:
                raise ValueError(
                    "cannot parse leaf line of xgboost dump: {0!r}".format(line))
            depth = result.group(1).count('\t')
            inode = result.group(2)
            res = result.group('res')
            ET.SubElement(
                parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
                depth=str(depth), NCoef="0", IVar="-1", Cut="0.0e+00",
                cType="1", res=str(res), rms="0.0e+00", purity="0.0e+00",
                nType="-99")
            continue

        result = _SPLIT_RE.match(line)
        if not result:
            raise ValueError(
                "cannot parse split line of xgboost dump: {0!r}".format(line))
        depth = result.group(1).count('\t')
        inode = result.group(2)
        var = result.group('var')
        cut = result.group('cut')
        lnode = result.group('yes')
        rnode = result.group('no')

        try:
            ivar = var_indices[var]
        except KeyError:
            raise KeyError(
                "split variable {0!r} is not listed in var_indices".format(var))

        pos_dict[lnode] = 'l'
        pos_dict[rnode] = 'r'
        node_element = ET.SubElement(
            parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
            depth=str(depth), NCoef="0", IVar=str(ivar), Cut=str(cut),
            cType="1", res="0.0e+00", rms="0.0e+00", purity="0.0e+00",
            nType="0")
        # Both children will be nested inside this split node.
        parent_element_dict[lnode] = node_element
        parent_element_dict[rnode] = node_element
def _indent_xml(element, level=0):
    """Insert newlines and two-space indentation below *element*, in place.

    Only whitespace-only ``text``/``tail`` values are replaced, so elements
    that carry real text keep it unchanged.
    """
    pad = "\n" + "  " * level
    children = list(element)
    if children:
        if not (element.text and element.text.strip()):
            element.text = pad + "  "
        for child in children:
            _indent_xml(child, level + 1)
        # The last child's tail sits right before the parent's closing tag.
        last = children[-1]
        if not (last.tail and last.tail.strip()):
            last.tail = pad
    if not (element.tail and element.tail.strip()):
        element.tail = pad


def convert_model(model, input_variables, output_xml, pretty=False):
    """Write an xgboost text dump as a TMVA-style ``MethodSetup`` XML file.

    The document contains a ``Variables`` section (one ``Variable`` per entry
    of *input_variables*), fixed ``GeneralInfo`` and ``Options`` sections, and
    a ``Weights`` section with one ``BinaryTree`` per dumped tree, filled in
    by ``build_tree``.

    Args:
        model: sequence of per-tree text dumps, e.g. ``bst.get_dump()``.
        input_variables: sequence of ``(name, type)`` entries, for example
            ``[('var1', 'F'), ('var2', 'I')]``. The position in the sequence
            becomes ``VarIndex``; the name is used to resolve the split
            variables found in the dumps.
        output_xml: file name or file object handed to ``ElementTree.write``.
        pretty: when true, indent the document with two spaces per level
            before writing. The default keeps the original single-line
            output, which can also be formatted afterwards with
            ``xmllint --format``.

    Returns:
        None.
    """
    var_indices = {}

    # <MethodSetup>
    method_setup = ET.Element("MethodSetup", Method="BDT::BDT")

    # <Variables>
    variables = ET.SubElement(
        method_setup, "Variables", NVar=str(len(input_variables)))
    for index, entry in enumerate(input_variables):
        name, var_type = entry[0], entry[1]
        var_indices[name] = index
        ET.SubElement(
            variables, "Variable", VarIndex=str(index), Type=var_type,
            Expression=name, Label=name, Title=name, Unit="", Internal=name,
            Min="0.0e+00", Max="0.0e+00")

    # <GeneralInfo>
    general_info = ET.SubElement(method_setup, "GeneralInfo")
    ET.SubElement(general_info, "Info", name="Creator", value="xgboost2TMVA")
    ET.SubElement(
        general_info, "Info", name="AnalysisType", value="Classification")

    # <Options>
    options = ET.SubElement(method_setup, "Options")
    ET.SubElement(
        options, "Option", name="NodePurityLimit", modified="No"
    ).text = "5.00e-01"
    ET.SubElement(
        options, "Option", name="BoostType", modified="Yes"
    ).text = "Grad"

    # <Weights>
    weights = ET.SubElement(
        method_setup, "Weights", NTrees=str(len(model)), AnalysisType="1")
    for itree, tree_dump in enumerate(model):
        binary_tree = ET.SubElement(
            weights, "BinaryTree", type="DecisionTree",
            boostWeight="1.0e+00", itree=str(itree))
        build_tree(tree_dump, binary_tree, var_indices)

    if pretty:
        _indent_xml(method_setup)
    ET.ElementTree(method_setup).write(output_xml)
# Example:
# bst = xgb.train(param, d_train, num_round, watchlist)
# model = bst.get_dump()
# convert_model(model, input_variables=[('var1', 'F'), ('var2', 'I')], output_xml='xgboost.xml')
@lgray Nice! Thank you for letting me know:)
Do you know why the output XML does not have the right indentation? I only see one line in the output
Never mind, I think it would work with:
import xml.dom.minidom
tree = ET.ElementTree(MethodSetup)
tmp_path = f'/tmp/{hash}.xml'
tree.write(tmp_path)
with open(tmp_path) as ifile:
xml_data = ifile.read()
parser = xml.dom.minidom.parseString(xml_data)
xml_pretty = parser.toprettyxml()
with open(output_xml, 'w') as ofile:
ofile.write(xml_pretty)
Do you know why the output XML does not have the right indentation? I only see one line in the output
One can also use xmllint --format
.
Do you know why the output XML does not have the right indentation? I only see one line in the output
One can also use
xmllint --format
.
Yeah, but it's better to do the formatting inside the Python code itself: that is easier when you want a well-formatted XML file and have no use for the unformatted one.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@hqucms After talking a bit with @guitargeek this now exists! https://github.com/guitargeek/tmva-to-xgboost