Skip to content

Instantly share code, notes, and snippets.

@0x0L
Last active March 4, 2018 00:09
Show Gist options
  • Save 0x0L/de41bcab16b07f55a42ac1ecfe0006a1 to your computer and use it in GitHub Desktop.
Save 0x0L/de41bcab16b07f55a42ac1ecfe0006a1 to your computer and use it in GitHub Desktop.
/// <summary>
/// A single LightGBM decision tree with numerical splits only, stored as
/// parallel arrays indexed by internal-node number.
/// </summary>
public class LightGBMTree
{
    /// <summary>Index into the feature vector tested at each internal node.</summary>
    public int[] split_feature;

    /// <summary>Split threshold of each internal node.</summary>
    public double[] threshold;

    /// <summary>
    /// Child taken when the test succeeds. A non-negative value is another
    /// internal node; a negative value <c>c</c> denotes leaf <c>-c - 1</c>.
    /// </summary>
    public int[] left_child;

    /// <summary>Child taken when the test fails; same encoding as <see cref="left_child"/>.</summary>
    public int[] right_child;

    /// <summary>Output value of each leaf.</summary>
    public double[] leaf_value;

    /// <summary>
    /// Walks the tree from the root (node 0) and returns the value of the
    /// leaf reached for the given feature vector.
    /// </summary>
    /// <param name="features">Feature values, indexed by <see cref="split_feature"/>.</param>
    /// <returns>The value of the selected leaf.</returns>
    public double Predict(double[] features)
    {
        // A tree that never split has no internal nodes: its only leaf is the answer.
        if (split_feature == null || split_feature.Length == 0)
            return leaf_value[0];

        var node = 0;
        for (;;)
        {
            // LightGBM numerical splits are inclusive: value <= threshold goes left.
            var left = features[split_feature[node]] <= threshold[node];
            node = left ? left_child[node] : right_child[node];
            if (node < 0)
                return leaf_value[-node - 1];
        }
    }
}
import re
# Fields that read_model() extracts from every "Tree=" section of the model
# text, as (field name, element constructor) pairs. Each field is looked up as
# a "<name>=<values>" line; the whitespace-separated values are converted with
# the constructor. A field missing from a section is stored as None.
# Note: for categorical nodes _predict() reads 'threshold' as an index into
# 'cat_boundaries', which in turn delimits that node's words in 'cat_threshold'.
_FIELDS = [
('split_feature', int),
('threshold', float),
('decision_type', int),
('left_child', int),
('right_child', int),
('leaf_value', float),
('cat_boundaries', int),
('cat_threshold', int),
]
def read_model(filename):
    """Parse a LightGBM text model file into a list of per-tree dicts.

    Args:
        filename: Path of the model text file to read.

    Returns:
        A list with one dict per ``Tree=`` section, in file order. Each dict
        maps every name in ``_FIELDS`` to the list of parsed values from the
        matching ``name=...`` line, or to ``None`` when the section has no
        such (non-empty) line.
    """
    # Read the file the caller asked for (previously 'model.txt' was hard-coded).
    with open(filename) as f:
        content = f.read()

    trees = []
    # A tree section starts at a line beginning with "Tree=" and runs up to
    # the first blank line after it.
    for header in re.finditer(r'^Tree=', content, re.MULTILINE):
        section = content[header.start():].split('\n\n', 1)[0]
        tree = {}
        for field, ctor in _FIELDS:
            # Anchor to the line start so 'threshold' cannot match inside
            # 'cat_threshold'; '$' also accepts the section's last line,
            # which has lost its trailing newline in the split above.
            found = re.search(
                r'^' + re.escape(field) + r'=([^\n]+)$', section, re.MULTILINE
            )
            tree[field] = (
                [ctor(x) for x in found.group(1).split()] if found else None
            )
        trees.append(tree)
    return trees
def _predict(tree, features):
    """Return the leaf value that a single tree assigns to ``features``.

    Args:
        tree: Dict of parallel per-node lists as produced by ``read_model``
            ('split_feature', 'threshold', 'decision_type', 'left_child',
            'right_child', 'leaf_value', and for categorical splits
            'cat_boundaries' / 'cat_threshold').
        features: Sequence of feature values indexed by 'split_feature'.

    Returns:
        The value of the leaf reached by walking the tree from node 0.

    Note:
        Missing-value (NaN) handling is not implemented.
    """
    # A tree without any split has a single leaf and no internal nodes.
    if not tree.get('split_feature'):
        return tree['leaf_value'][0]

    node = 0
    while True:
        fval = features[tree['split_feature'][node]]
        if tree['decision_type'][node] & 1:
            # Categorical split: 'threshold' holds an index into
            # 'cat_boundaries', which delimits this node's run of 32-bit
            # words in 'cat_threshold'. A set bit sends that category left.
            int_fval = int(fval)
            cat_idx = int(tree['threshold'][node])
            cat_boundaries = tree['cat_boundaries']
            first_word = cat_boundaries[cat_idx]
            n_words = cat_boundaries[cat_idx + 1] - first_word
            # Check the range BEFORE indexing: negative or too-large
            # categories are not in the bitset and must go right, rather than
            # raising IndexError or reading another node's words.
            if 0 <= int_fval < 32 * n_words:
                word = tree['cat_threshold'][first_word + int_fval // 32]
                left = (word >> (int_fval & 31)) & 1
            else:
                left = False
        else:
            # Numerical split: inclusive comparison goes left.
            left = fval <= tree['threshold'][node]
        node = tree['left_child' if left else 'right_child'][node]
        # Negative child indices encode leaves: child c is leaf -c - 1.
        if node < 0:
            return tree['leaf_value'][-node - 1]
def predict(trees, features):
    """Return the sum of every tree's ``_predict`` output for ``features``.

    Args:
        trees: Iterable of tree dicts, e.g. the list returned by ``read_model``.
        features: Sequence of feature values passed unchanged to each tree.

    Returns:
        The accumulated leaf values; ``0`` when ``trees`` is empty.
    """
    total = 0
    for tree in trees:
        total += _predict(tree, features)
    return total
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment