Skip to content

Commit

Permalink
Add scripts for pdp and shap comparison figures
Browse files Browse the repository at this point in the history
  • Loading branch information
agosiewska committed Feb 15, 2021
1 parent 961de09 commit 9e21ed5
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 0 deletions.
Binary file modified figures/pdps.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
74 changes: 74 additions & 0 deletions scripts/pdp_times.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
library(gridExtra)

############# DALEX

library(DALEX)
code_to_eval_DALEX <- 'data(titanic_imputed, package = "DALEX")
ranger_model <- ranger::ranger(survived~., data = titanic_imputed, classification = TRUE, probability = TRUE)
explainer_ranger <- explain(ranger_model,
data = titanic_imputed, y = titanic_imputed$survived,
label = "Ranger Model", verbose = FALSE)
pdp_ranger <- model_profile(explainer_ranger, variables = "fare", type = "partial")
plot(pdp_ranger)'

system.time(eval(expr = parse(text=code_to_eval_DALEX)))[3]

############# flashlight

library(flashlight)
library(MetricsWeighted)
code_to_eval_flashlight <- 'data(titanic_imputed, package = "DALEX")
ranger_model <- ranger::ranger(survived~., data = titanic_imputed, classification = TRUE, probability = TRUE)
custom_predict <- function(X.model, new_data) {
predict(X.model, new_data)$predictions[,2]
}
fl <- flashlight(model = ranger_model, data = titanic_imputed, y = "survived", label = "Titanic Ranger",
metrics = list(auc = AUC), predict_function = custom_predict)
pdp <- light_profile(fl, v = "fare", type = "partial dependence")
plot(pdp)'

system.time(eval(expr = parse(text=code_to_eval_flashlight)))[3]


############# iml
library(iml)
code_to_eval_iml <- 'data(titanic_imputed, package = "DALEX")
rf_model <- ranger::ranger(survived~., data = titanic_imputed, classification = TRUE, probability = TRUE)
X <- titanic_imputed[which(names(titanic_imputed) != "survived")]
pred_fun <- function(X.model, newdata) {
predict(X.model, newdata)$predictions[,2]
}
predictor <- Predictor$new(rf_model, data = X, y = titanic_imputed$survived, predict.function = pred_fun)
pdp <- FeatureEffect$new(predictor, feature = "fare", method = "pdp")
plot(pdp)'

system.time(eval(expr = parse(text=code_to_eval_iml)))[3]


############# pdp
library(pdp)
library(randomForest)

code_to_eval_pdp <- 'data(titanic_imputed, package = "DALEX")
rf_model <- randomForest(factor(survived)~., data = titanic_imputed[1:10,])
pred_fun <- function(object, newdata) predict(object, newdata, type = "prob")[,2]
rf_pdp <- partial(rf_model, pred.var = c("fare"), pred.fun = pred_fun )
plotPartial(rf_pdp)'

system.time(eval(expr = parse(text=code_to_eval_pdp)))[3]



#### generate figure

plot_DALEX <- eval(expr = parse(text=code_to_eval_DALEX))
plot_flashlight <- eval(expr = parse(text=code_to_eval_flashlight))
plot_iml <- eval(expr = parse(text=code_to_eval_iml))
plot_pdp <- eval(expr = parse(text=code_to_eval_pdp))

png("figures/pdps2.png", height = 900, width = 1200, units="px")
grid.arrange(plot_DALEX, plot_flashlight, plot_iml, plot_pdp, ncol = 2)
dev.off()

77 changes: 77 additions & 0 deletions scripts/shap_times.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
library(gridExtra)

############# DALEX

library(DALEX)
code_to_eval_DALEX <- 'data(titanic_imputed, package = "DALEX")
ranger_model <- ranger::ranger(survived~., data = titanic_imputed, classification = TRUE, probability = TRUE)
explainer_ranger <- DALEX::explain(ranger_model, data = titanic_imputed, y = titanic_imputed$survived, label = "Ranger Model", verbose = FALSE)
shap_ranger <- predict_parts(explainer_ranger, new_observation = titanic_imputed[1,], type = "shap", B = 50)
plot(shap_ranger)'

system.time(eval(expr = parse(text=code_to_eval_DALEX)))[3]

############# fastshap

library(fastshap)
code_to_eval_fastshap <- 'data(titanic_imputed, package = "DALEX")
ranger_model <- ranger::ranger(survived~., data = titanic_imputed, classification = TRUE, probability = TRUE)
pred_fun <- function(X.model, newdata) {
predict(X.model, newdata)$predictions[,2]
}
shap <- explain(ranger_model, X = titanic_imputed, pred_wrapper = pred_fun, nsim = 50)
library(ggplot2)
autoplot(shap, type = "contribution", row_num = 1)'

system.time(eval(expr = parse(text=code_to_eval_fastshap)))[3]


############# iml
library(iml)
code_to_eval_iml <- 'data(titanic_imputed, package = "DALEX")
rf_model <- ranger::ranger(survived~., data = titanic_imputed, classification = TRUE, probability = TRUE)
X <- titanic_imputed[which(names(titanic_imputed) != "survived")]
pred_fun <- function(X.model, newdata) {
predict(X.model, newdata)$predictions[,2]
}
predictor <- Predictor$new(rf_model, data = X, y = titanic_imputed$survived, predict.function = pred_fun)
shapley <- Shapley$new(predictor, x.interest = X[1, ], sample.size = 50)
plot(shapley)'

system.time(eval(expr = parse(text=code_to_eval_iml)))[3]


############# shapper
library(randomForest)
library(shapper)

code_to_eval_shapper <- 'data(titanic_imputed, package = "DALEX")
rf_model <- ranger::ranger(survived~., data = titanic_imputed, classification = TRUE, probability = TRUE)
pred_fun <- function(X.model, newdata) {
predict(X.model, newdata)$predictions[,2]
}
ive_rf <- individual_variable_effect(rf_model,
predict_function = pred_fun,
data = titanic_imputed[,-8],
new_observation = titanic_imputed[1, -8],
nsamples = 50)
plot(ive_rf)'

system.time(eval(expr = parse(text=code_to_eval_shapper)))[3]



#### generate figure

plot_DALEX <- eval(expr = parse(text=code_to_eval_DALEX))
plot_fastshap <- eval(expr = parse(text=code_to_eval_fastshap))
plot_iml <- eval(expr = parse(text=code_to_eval_iml))
plot_shapper <- eval(expr = parse(text=code_to_eval_shapper))

png("figures/pdps2.png", height = 900, width = 1200, units="px")
grid.arrange(plot_DALEX, plot_flashlight, plot_iml, plot_pdp, ncol = 2)
dev.off()

0 comments on commit 9e21ed5

Please sign in to comment.