Skip to content

Commit

Permalink
Get plots tested on travis
Browse files Browse the repository at this point in the history
  • Loading branch information
slundberg committed Feb 11, 2018
1 parent 2a57a48 commit 3b0caaf
Show file tree
Hide file tree
Showing 8 changed files with 628 additions and 51,046 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ shap.egg-info
notebooks/local_scratch
docs
.ipynb_checkpoints
.eggs
5 changes: 5 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@ install:
- pip install .
# command to run tests
script: python setup.py nosetests

before_script: # configure a headless display to test plot generation
- "export DISPLAY=:99.0"
- "sh -e /etc/init.d/xvfb start"
- sleep 3 # give xvfb some time to start
54 changes: 37 additions & 17 deletions notebooks/Basic SHAP Interaction Value Example in XGBoost.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

50,430 changes: 0 additions & 50,430 deletions notebooks/Front page example with XGBoost.ipynb

This file was deleted.

Binary file removed notebooks/data/nhanes_age_sex_interaction.pdf
Binary file not shown.
61 changes: 35 additions & 26 deletions shap/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import matplotlib.pyplot as pl
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MaxNLocator

cdict1 = {
'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
Expand All @@ -25,13 +26,13 @@
'alpha': ((0.0, 1, 1),
(0.5, 0.3, 0.3),
(1.0, 1, 1))
}
} # #1E88E5 -> #ff0052
red_blue = LinearSegmentedColormap('RedBlue', cdict1)
except ImportError:
pass

def dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
interaction_index="auto", color="#ff0052", axis_color="#333333",
interaction_index="auto", color="#1E88E5", axis_color="#333333",
dot_size=16, alpha=1, title=None, show=True):
"""
Create a SHAP dependence plot, colored by an interaction feature.
Expand All @@ -53,7 +54,7 @@ def dependence_plot(ind, shap_values, features, feature_names=None, display_feat
display_features : numpy.array or pandas.DataFrame
Matrix of feature values for visual display (such as strings instead of coded values)
interaction_index : "auto" or int
interaction_index : "auto", None, or int
The index of the feature used to color the plot.
"""

Expand Down Expand Up @@ -103,7 +104,11 @@ def convert_name(ind):
ind1, proj_shap_values, features, feature_names=feature_names,
interaction_index=ind2, display_features=display_features, show=False
)
pl.ylabel("SHAP interaction value for\n"+feature_names[ind1]+" and "+feature_names[ind2])
if ind1 == ind2:
pl.ylabel("SHAP main effect value for\n"+feature_names[ind1])
else:
pl.ylabel("SHAP interaction value for\n"+feature_names[ind1]+" and "+feature_names[ind2])

if show:
pl.show()
return
Expand Down Expand Up @@ -140,7 +145,7 @@ def convert_name(ind):
cname_map[cd[i]] = cv[i]
cnames = list(cname_map.keys())
categorical_interaction = True
elif clow % 1 == 0 and chigh % 1 == 0:
elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:,interaction_index])) < 50:
categorical_interaction = True

# discretize colors for categorical features
Expand All @@ -149,33 +154,37 @@ def convert_name(ind):
bounds = np.linspace(clow, chigh, chigh-clow+2)
color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)

# the actual scatter plot, TODO: adapt the dot_size to the number of data points
# the actual scatter plot, TODO: adapt the dot_size to the number of data points?
pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:,interaction_index], cmap=red_blue,
alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500)

# draw the color bar
norm = None
if type(cd[0]) == str:
tick_positions = [cname_map[n] for n in cnames]
if len(tick_positions) == 2:
tick_positions[0] -= 0.25
tick_positions[1] += 0.25
cb = pl.colorbar(ticks=tick_positions)
cb.set_ticklabels(cnames)
if interaction_index != ind:
# draw the color bar
norm = None
if type(cd[0]) == str:
tick_positions = [cname_map[n] for n in cnames]
if len(tick_positions) == 2:
tick_positions[0] -= 0.25
tick_positions[1] += 0.25
cb = pl.colorbar(ticks=tick_positions)
cb.set_ticklabels(cnames)

else:
cb = pl.colorbar()
cb.set_label(feature_names[interaction_index], size=13)
cb.ax.tick_params(labelsize=11)
if categorical_interaction:
cb.ax.tick_params(length=0)
cb.set_alpha(1)
cb.outline.set_visible(False)
bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
cb.ax.set_aspect((bbox.height-0.7)*20)
else:
cb = pl.colorbar()
cb.set_label(feature_names[interaction_index], size=13)
cb.ax.tick_params(labelsize=11)
if categorical_interaction:
cb.ax.tick_params(length=0)
cb.set_alpha(1)
cb.outline.set_visible(False)
bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
cb.ax.set_aspect((bbox.height-0.7)*20)

# make the plot more readable
pl.gcf().set_size_inches(7.5, 5)
if interaction_index != ind:
pl.gcf().set_size_inches(7.5, 5)
else:
pl.gcf().set_size_inches(6, 5)
pl.xlabel(name, color=axis_color, fontsize=13)
pl.ylabel("SHAP value for\n"+name, color=axis_color, fontsize=13)
if title != None:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def test_front_page_xgboost():
shap.visualize(shap_values, X)

# create a SHAP dependence plot to show the effect of a single feature across the whole dataset
#shap.dependence_plot(5, shap_values, X, show=False)
#shap.dependence_plot("RM", shap_values, X, show=False)
shap.dependence_plot(5, shap_values, X, show=False)
shap.dependence_plot("RM", shap_values, X, show=False)

# summarize the effects of all the features
#shap.summary_plot(shap_values, X, show=False)
shap.summary_plot(shap_values, X, show=False)

0 comments on commit 3b0caaf

Please sign in to comment.