forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[R-package] GPL2 dependency reduction and some fixes (dmlc#1401)
* [R] do not remove zero coefficients from gblinear dump
* [R] switch from stringr to stringi
* fix dmlc#1399
* [R] separate ggplot backend, add base r graphics, cleanup, more plots, tests
* add missing include in amalgamation - fixes building R package in linux
* add forgotten file
* [R] fix DESCRIPTION
* [R] fix travis check issue and some cleanup
- Loading branch information
Showing
19 changed files
with
541 additions
and
305 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# ggplot backend for the xgboost plotting facilities | ||
|
||
|
||
#' @rdname xgb.plot.importance
#' @export
xgb.ggplot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
                                  rel_to_first = FALSE, n_clusters = c(1:10), ...) {
  # Delegate the preparation of the importance table to the base plotting
  # function, asking it not to draw anything (plot = FALSE).
  importance_matrix <- xgb.plot.importance(
    importance_matrix,
    top_n = top_n,
    measure = measure,
    rel_to_first = rel_to_first,
    plot = FALSE,
    ...
  )

  # ggplot2 and Ckmeans.1d.dp are only looked up at call time.
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required", call. = FALSE)
  }
  if (!requireNamespace("Ckmeans.1d.dp", quietly = TRUE)) {
    stop("Ckmeans.1d.dp package is required", call. = FALSE)
  }

  # Cluster the importance values in one dimension; the cluster id (stored as
  # a character column) drives the fill colour of the bars.
  clusters <- suppressWarnings(
    Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix$Importance, n_clusters)
  )
  importance_matrix[, Cluster := as.character(clusters$cluster)]

  # Features are placed on the x axis in reversed table order, so that after
  # coord_flip() the first row of the table ends up at the top.
  bar_mapping <- ggplot2::aes(
    x = factor(Feature, levels = rev(Feature)),
    y = Importance,
    width = 0.05
  )
  bar_layer <- ggplot2::geom_bar(
    ggplot2::aes(fill = Cluster),
    stat = "identity",
    position = "identity"
  )
  look <- ggplot2::theme(
    plot.title = ggplot2::element_text(lineheight = .9, face = "bold"),
    panel.grid.major.y = ggplot2::element_blank()
  )

  p <- ggplot2::ggplot(importance_matrix, bar_mapping, environment = environment()) +
    bar_layer +
    ggplot2::coord_flip() +
    ggplot2::xlab("Features") +
    ggplot2::ggtitle("Feature importance") +
    look
  p
}
|
||
|
||
#' @rdname xgb.plot.deepness
#' @export
xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight")) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required for plotting the graph deepness.", call. = FALSE)
  }

  which <- match.arg(which)

  # Leaf-level table obtained from the base plotting function without drawing;
  # the code below reads its Depth, Cover, Weight and Tree columns.
  dt_depths <- xgb.plot.deepness(model = model, plot = FALSE)

  # Per-depth summary: number of rows (N) and mean Cover, ordered by Depth.
  dt_summaries <- dt_depths[, list(.N, Cover = mean(Cover)), by = Depth]
  setkey(dt_summaries, "Depth")

  switch(
    which,
    "2x1" = {
      # Upper panel: leaf count per depth, with the x axis decorations removed.
      upper_theme <- ggplot2::theme(
        plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
        panel.grid.major.y = ggplot2::element_blank(),
        axis.ticks = ggplot2::element_blank(),
        axis.text.x = ggplot2::element_blank()
      )
      p1 <- ggplot2::ggplot(dt_summaries) +
        ggplot2::geom_bar(ggplot2::aes(x = Depth, y = N), stat = "Identity") +
        ggplot2::xlab("") +
        ggplot2::ylab("Number of leafs") +
        ggplot2::ggtitle("Model complexity") +
        upper_theme

      # Lower panel: mean cover per depth.
      p2 <- ggplot2::ggplot(dt_summaries) +
        ggplot2::geom_bar(ggplot2::aes(x = Depth, y = Cover), stat = "Identity") +
        ggplot2::xlab("Leaf depth") +
        ggplot2::ylab("Weighted cover")

      # Draw both panels stacked in one column, then hand the plot objects
      # back invisibly.
      multiplot(p1, p2, cols = 1)
      return(invisible(list(p1, p2)))
    },
    "max.depth" = {
      # One point per tree: the largest Depth value in that tree (column V1).
      per_tree <- dt_depths[, max(Depth), by = Tree]
      p <- ggplot2::ggplot(per_tree) +
        ggplot2::geom_jitter(
          ggplot2::aes(x = Tree, y = V1),
          height = 0.15, alpha = 0.4, size = 3, stroke = 0
        ) +
        ggplot2::xlab("tree #") +
        ggplot2::ylab("Max tree leaf depth")
      return(p)
    },
    "med.depth" = {
      # One point per tree: the median Depth value in that tree (column V1).
      per_tree <- dt_depths[, median(as.numeric(Depth)), by = Tree]
      p <- ggplot2::ggplot(per_tree) +
        ggplot2::geom_jitter(
          ggplot2::aes(x = Tree, y = V1),
          height = 0.15, alpha = 0.4, size = 3, stroke = 0
        ) +
        ggplot2::xlab("tree #") +
        ggplot2::ylab("Median tree leaf depth")
      return(p)
    },
    "med.weight" = {
      # One point per tree: the median absolute Weight in that tree (column V1).
      per_tree <- dt_depths[, median(abs(Weight)), by = Tree]
      p <- ggplot2::ggplot(per_tree) +
        ggplot2::geom_point(
          ggplot2::aes(x = Tree, y = V1),
          alpha = 0.4, size = 3, stroke = 0
        ) +
        ggplot2::xlab("tree #") +
        ggplot2::ylab("Median absolute leaf weight")
      return(p)
    }
  )
}
|
||
# Plot multiple ggplot graphs aligned by rows and columns.
#   ...   the plots to draw; each one is drawn via print()
#   cols  number of columns in the layout grid
# Returns (invisibly) NULL when no plot or several plots are given; with a
# single plot, returns whatever print() returns for it.
# Internal utility function.
multiplot <- function(..., cols = 1) {
  plots <- list(...)
  num_plots <- length(plots)

  # Nothing to draw: return quietly instead of opening a page and then
  # failing on plots[[1]].
  if (num_plots == 0L) {
    return(invisible(NULL))
  }

  # A single plot needs no layout.
  if (num_plots == 1L) {
    return(print(plots[[1]]))
  }

  # Layout cells are numbered column by column; plot i goes into the cell
  # holding the value i.
  n_rows <- ceiling(num_plots / cols)
  layout <- matrix(seq_len(cols * n_rows), nrow = n_rows, ncol = cols)

  grid::grid.newpage()
  grid::pushViewport(grid::viewport(layout = grid::grid.layout(nrow(layout), ncol(layout))))
  for (i in seq_len(num_plots)) {
    # Row/column position of the layout cell that contains this subplot.
    cell <- which(layout == i, arr.ind = TRUE)

    print(
      plots[[i]],
      vp = grid::viewport(
        layout.pos.row = cell[[1L, "row"]],
        layout.pos.col = cell[[1L, "col"]]
      )
    )
  }
  invisible(NULL)
}
|
||
# Register names that are used without a visible binding in this file, so
# that the code checks do not report them as undefined global variables.
globalVariables(
  c(
    "Cluster",
    "ggplot",
    "aes",
    "geom_bar",
    "coord_flip",
    "xlab",
    "ylab",
    "ggtitle",
    "theme",
    "element_blank",
    "element_text"
  )
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.