Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Nov 22, 2020
1 parent 4c6391e commit 8e5882b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions shap/explainers/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def __init__(self, model, data=None, data_missing=None, model_output=None):
self.model_output = "probability_doubled" # with predict_proba we need to double the outputs to match
else:
self.model_output = "probability"
output_trees = [[] for i in range(self.num_stacked_models)]
self.trees = []
for p in model._predictors:
for i in range(self.num_stacked_models):
nodes = p[i].nodes
Expand All @@ -754,8 +754,7 @@ def __init__(self, model, data=None, data_missing=None, model_output=None):
"values": np.array([[n[0]] for n in nodes], dtype=np.float64),
"node_sample_weight": np.array([n[1] for n in nodes], dtype=np.float64),
}
output_trees[i].append(SingleTree(tree, data=data, data_missing=data_missing))
self.trees = list(itertools.chain.from_iterable(output_trees))
self.trees.append(SingleTree(tree, data=data, data_missing=data_missing))
self.objective = objective_name_map.get(model.loss, None)
self.tree_output = "log_odds"
elif safe_isinstance(model, ["sklearn.ensemble.GradientBoostingClassifier","sklearn.ensemble._gb.GradientBoostingClassifier", "sklearn.ensemble.gradient_boosting.GradientBoostingClassifier"]):
Expand Down Expand Up @@ -1007,6 +1006,7 @@ def __init__(self, model, data=None, data_missing=None, model_output=None):
self.features[i,:len(self.trees[i].features)] = self.trees[i].features
self.thresholds[i,:len(self.trees[i].thresholds)] = self.trees[i].thresholds
if self.num_stacked_models > 1:
# stack_pos = int(i // (num_trees / self.num_stacked_models))
stack_pos = i % self.num_stacked_models
self.values[i,:len(self.trees[i].values[:,0]),stack_pos] = self.trees[i].values[:,0]
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/explainers/test_gpu_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def idfn(task):
[pytest.param("interventional", marks=pytest.mark.xfail),
"tree_path_dependent"])
def test_gpu_tree_explainer_shap(task, feature_perturbation):
model, X, margin = task
model, X, _ = task
ex = shap.GPUTreeExplainer(model, X, feature_perturbation=feature_perturbation)
ex.shap_values(X, check_additivity=True)

Expand Down

0 comments on commit 8e5882b

Please sign in to comment.