Last active
January 22, 2021 04:57
-
-
Save HeardACat/1529a5bcc730de83f476da2a1096c6fe to your computer and use it in GitHub Desktop.
This is to demonstrate how we could naively convert a tree in River to work with the Shap library. This is in order to start a discussion. https://github.com/online-ml/river/issues/437
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The goal of this script is to use SHAP to explain a tree built with River.
# See the discussion at https://github.com/online-ml/river/issues/437
from functools import reduce | |
import operator | |
import numpy as np | |
import pandas as pd | |
import pprint | |
from sklearn import datasets | |
import lightgbm as lgb | |
import matplotlib.pyplot as plt | |
import shap | |
from river import tree | |
from river import stream | |
from river.utils.skmultiflow_utils import normalize_values_in_dict | |
from river.utils.skmultiflow_utils import round_sig_fig | |
def get_all_path(tree):
    """Map every node number to the path that leads to it from the root.

    Parameters
    ----------
    tree
        Iterable of ``(parent_no, child_no, parent, child, branch_id)`` edge
        tuples, e.g. the result of calling ``iter_edges()`` on a tree root.
        A parent must be listed before any of its children.

    Returns
    -------
    dict
        ``{node_no: [(branch_id, node_no), ...]}`` where each list starts at
        the root entry and ends at ``node_no`` itself. Edges whose parent has
        not been seen yet are skipped.
    """
    # Index paths by the node they end at, so a parent's path is found with a
    # single lookup instead of scanning (and mutating) a list being iterated.
    path_dict = {}
    for parent_no, child_no, _parent, _child, branch_id in tree:
        step = (branch_id, child_no)
        if parent_no is None:
            # Root entry: it has no parent, so its path is just itself.
            path_dict[child_no] = [step]
            continue
        parent_path = path_dict.get(parent_no)
        if parent_path is not None:
            # Build a new list so the parent's own path is left untouched.
            path_dict[child_no] = parent_path + [step]
    return path_dict
def count_leaf(tree):
    """Count the edges in *tree* whose child node reports itself as a leaf.

    *tree* is an iterable of ``(parent_no, child_no, parent, child,
    branch_id)`` tuples; only ``child.is_leaf()`` is consulted.
    """
    return sum(1 for _, _, _, node, _ in tree if node.is_leaf())
# Adapted from
# https://stackoverflow.com/questions/14692690/access-nested-dictionary-items-via-a-list-of-keys
def get_by_path(root, items):
    """Return the object reached by indexing *root* with each key of *items* in turn.

    An empty *items* sequence returns *root* itself.
    """
    current = root
    for key in items:
        current = current[key]
    return current
def set_by_path(root, items, value):
    """Store *value* inside the nested object *root* at the location named by *items*.

    Every key except the last is used to walk down from *root*; the last key
    is the slot that receives *value* in the container reached that way.
    """
    parent_keys = items[:-1]
    final_key = items[-1]
    container = get_by_path(root, parent_keys)
    container[final_key] = value
class FakeLightGBMBooster:
    """Stand-in for a LightGBM booster so a dict tree dump can be loaded into shap.

    The instance reports ``lgb.basic.Booster`` as its ``__class__`` and exposes
    the two things this script relies on: ``params`` and ``dump_model()``.

    Parameters
    ----------
    model
        The wrapped model; ``predict`` calls its ``predict_one`` method.
    model_struct
        A single tree dump (the dict produced by ``dump_tree_model``).
    feature_names
        Optional feature names forwarded to ``stream.iter_array`` so that the
        feature dicts built in ``predict`` use the same keys the wrapped model
        was trained with. ``None`` keeps ``iter_array``'s default keys.
    """

    def __init__(self, model, model_struct, feature_names=None):
        self.model_struct = model_struct
        self.model = model
        # Kept private so it does not collide with any attribute of a real booster.
        self._feature_names = feature_names
        self.params = {"objective": "binary"}

    def dump_model(self):
        """Return the stored tree wrapped as a one-tree ``tree_info`` dump."""
        return {"tree_info": [self.model_struct]}

    def predict(self, X, *args, **kwargs):
        """Predict each row of *X* with the wrapped model and return a numpy array.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # ``stream.iter_array`` yields ``(x, y)`` pairs; only the feature dict
        # ``x`` must be handed to ``predict_one`` (passing the whole pair was a bug).
        return np.array(
            [
                self.model.predict_one(x)
                for x, _ in stream.iter_array(X, feature_names=self._feature_names)
            ]
        )

    # NOTE(review): Python looks ``__instancecheck__`` up on the metaclass, so
    # this hook on the class itself is probably never consulted; the
    # ``__class__`` property below is what makes ``isinstance`` succeed.
    @classmethod
    def __instancecheck__(cls, instance):
        return isinstance(instance, lgb.basic.Booster)

    @property
    def __class__(self):
        # Masquerade as a real booster for ``isinstance``-based type checks.
        return lgb.basic.Booster
def dump_tree_model(htc: tree.HoeffdingTreeClassifier, feature_names=None):
    """Convert a Hoeffding tree into a LightGBM-style nested tree dump.

    Parameters
    ----------
    htc
        The tree to export; its ``_tree_root.iter_edges()`` is walked.
    feature_names
        Ordered feature names; a feature's position in this sequence becomes
        its ``split_feature`` index. When ``None`` the module-level
        ``dataset.feature_names`` is used, which keeps existing one-argument
        callers working.

    Returns
    -------
    dict
        ``{"tree_structure": <nested node dicts>, "num_leaves": <int>}``.
        Nodes are nested under ``"left_child"`` / ``"right_child"`` keys.

    Notes
    -----
    The export presumes binary splits (branch ids 0 and 1) whose split test
    exposes ``_att_idx`` and ``_att_value``, and a classifier whose node
    ``stats`` map class labels to vote weights.
    """
    if feature_names is None:
        # Backward-compatible fallback to the demo script's global dataset.
        feature_names = dataset.feature_names
    feature_names_mapping = {name: idx for idx, name in enumerate(feature_names)}

    num_leaves = count_leaf(htc._tree_root.iter_edges())
    path_dict = get_all_path(htc._tree_root.iter_edges())
    left_right = {0: "left_child", 1: "right_child"}
    info_json = None

    # lightgbm's json dump resets the indices for the leaf and node, we'll
    # replicate that here rather than using `child_no`
    curr_leaf_index = 0
    curr_node_indx = 0

    for (
        parent_no,
        child_no,
        parent,
        child,
        branch_id,
    ) in htc._tree_root.iter_edges():
        # lightgbm structure reports the stats for all nodes
        pred = child.stats
        max_class = max(pred, key=pred.get)

        # copied from river... assume it's a classifier for now
        sum_votes = sum(pred.values())
        # Without any votes, fall back to the majority class label itself.
        probas = max_class
        if sum_votes > 0:
            pred = normalize_values_in_dict(pred, factor=sum_votes, inplace=False)
            # Rounded share of the votes held by the majority class.
            probas = round_sig_fig(pred[max(pred, key=pred.get)])
        internal_value = probas

        # Both the internal_* and leaf_* keys are filled for every node; keys
        # starting with an underscore are bookkeeping added by this export.
        stat_dict = {
            "internal_value": internal_value,
            "internal_count": max(int(sum_votes), 1),
            "leaf_value": internal_value,
            "leaf_count": max(int(sum_votes), 1),
            "default_left": True,
            "_parent": parent_no,
            "_node": child_no,
            "_branch_id": branch_id,
        }

        if child.is_leaf():
            # stat_dict['leaf_index'] = child_no  # - this is lightgbm style
            # i don't think it matters in shap, and may be preferable to keep
            # an internal index like this
            stat_dict["leaf_index"] = curr_leaf_index  # hack for shap
            curr_leaf_index += 1
        else:
            stat_dict["split_index"] = curr_node_indx
            curr_node_indx += 1

            # Defaults taken straight from the split test; they are what
            # remains when no textual condition is available for branch_id.
            condition_extract = child.split_test.describe_condition_for_branch(branch_id)
            stat_dict["split_feature"] = feature_names_mapping[child.split_test._att_idx]
            stat_dict["threshold"] = child.split_test._att_value

            if condition_extract is not None:
                # convert condition to: split_feature, decision_type, threshold
                # (peel the last two space-separated tokens off the right so
                # feature names containing spaces stay intact)
                output = condition_extract.strip().rsplit(" ", 1)
                condition_extract, threshold = output[0], output[1]
                output = condition_extract.strip().rsplit(" ", 1)
                feature_name, decision_type = output[0], output[1]
                split_feature = feature_names_mapping[feature_name.strip()]

                # this dump presumes binary classes, and the convention is that
                # if the decision type contains "=" or "<" it gets routed as
                # the left node.
                stat_dict["split_feature"] = split_feature
                # NOTE(review): this is the threshold *text* parsed from the
                # condition, so it replaces the value set above with a string.
                stat_dict["threshold"] = threshold
                stat_dict["decision_type"] = decision_type
                stat_dict["_child_type"] = (
                    "left_child"
                    if "=" in decision_type or "<" in decision_type
                    else "right_child"
                )

        if info_json is None:
            # The first node visited becomes the root of the dump.
            info_json = stat_dict
        else:
            # figure out the path to the node, and attach it appropriately
            # we also need to fill in the split logic upstream (maybe)
            path = path_dict[child_no]
            # unravel path to get key; the root entry has branch id None and
            # contributes no key
            map_keys = [left_right[x[0]] for x in path if x[0] is not None]
            set_by_path(info_json, map_keys, stat_dict)

    tree_struct = {
        "tree_structure": info_json,
        "num_leaves": num_leaves,
    }
    return tree_struct
#################
# Demo script. These statements run at module level (there is no
# `if __name__ == "__main__":` guard). `dataset` must stay a module-level
# name because dump_tree_model() reads `dataset.feature_names`.
# Load the data
dataset = datasets.load_breast_cancer()
X, y = dataset.data, dataset.target
lgdt = lgb.LGBMClassifier(n_estimators=1)
htc = tree.HoeffdingTreeClassifier()
lgdt.fit(dataset.data, dataset.target)
json_model = (
    lgdt.booster_.dump_model()
)  # we can use tree_info dictionary type + SingleTree to use shap.
tree_info_json = json_model["tree_info"][
    0
]  # index 0: the model has one tree (n_estimators=1) and we only care about the single-tree format
# we use the keys, tree_index, num_leaves, num_cat, shrinkage, tree_structure
# shap only cares about tree_structure, and num_leaves, to get the num_parents.
tree_struct = tree_info_json["tree_structure"]
print("This is an example of the LightGBM tree dump:")
pprint.pprint(tree_struct)
print("---------")
# it then uses: left_child, split index, right_child, leaf_index, threshold, internal_value, internal_count,
# internal_count: number of records from the training data that fall into this non-leaf node
# internal_value: raw predicted value that would be produced by this node if it was a leaf node
# Train the Hoeffding tree with ten passes over the dataset, one sample at a time.
for _ in range(10):
    for xi, yi in stream.iter_sklearn_dataset(dataset):
        # print(xi, yi)
        htc.learn_one(xi, yi)
# htc._tree_root is the root of the tree.
# we need to examine how to extract a tree from the learning node objects
# for example, the method `iter_edges` may give us a hint.
# as shap only supports binary trees, and has other constraints, let's see how
# this would work in this setting...
# we can create a json_tree info object similar to lightGBM for exporting
# purposes.
# htc._tree_root
# https://github.com/slundberg/shap/blob/474fc74bc0a93911879248ee9f651dcea67270fd/shap/explainers/_tree.py#L1119
# https://github.com/online-ml/river/blob/bf012736ee4bb5152d2e20ab11beedc6957e8294/river/tree/_base_tree.py
print("Replicating this tree format here...")
# NOTE: this rebinds `tree_struct`, replacing the LightGBM dump printed above.
tree_struct = dump_tree_model(htc)
pprint.pprint(tree_struct)
print("---------")
fake_model = FakeLightGBMBooster(htc, tree_struct)
explainer = shap.TreeExplainer(fake_model)
# explainer.model --> TreeEnsemble object
# explainer.model.trees --> the internal tree object.
# Override the settings shap derived from the fake booster.
explainer.model.tree_output = "raw_value"
explainer.model.objective = None
explainer.model.model_type = "internal"
# TODO: as check_additivity is set to False - I can't guarantee that this makes sense in this context
shap_values = explainer.shap_values(dataset.data, check_additivity=False)
# shap.force_plot(explainer.expected_value, shap_values, dataset.data).html()
# show=False leaves the figure open so it can be saved to disk below.
shap.summary_plot(shap_values, dataset.feature_names, show=False)
f = plt.gcf()
f.savefig("output.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment