Skip to content

Commit

Permalink
[R-package] GPL2 dependency reduction and some fixes (dmlc#1401)
Browse files Browse the repository at this point in the history
* [R] do not remove zero coefficients from gblinear dump

* [R] switch from stringr to stringi

* fix dmlc#1399

* [R] separate ggplot backend, add base r graphics, cleanup, more plots, tests

* add missing include in amalgamation - fixes building R package in linux

* add forgotten file

* [R] fix DESCRIPTION

* [R] fix travis check issue and some cleanup
  • Loading branch information
khotilov authored and hetong007 committed Jul 27, 2016
1 parent f642305 commit d5c1433
Show file tree
Hide file tree
Showing 19 changed files with 541 additions and 305 deletions.
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ Imports:
methods,
data.table (>= 1.9.6),
magrittr (>= 1.5),
stringr (>= 0.6.2)
stringi (>= 0.5.2)
RoxygenNote: 5.0.1
15 changes: 9 additions & 6 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ export(xgb.attributes)
export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
export(xgb.ggplot.deepness)
export(xgb.ggplot.importance)
export(xgb.importance)
export(xgb.load)
export(xgb.model.dt.tree)
Expand All @@ -53,15 +55,16 @@ importFrom(data.table,":=")
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,setkey)
importFrom(data.table,setkeyv)
importFrom(data.table,setnames)
importFrom(magrittr,"%>%")
importFrom(stats,predict)
importFrom(stringr,str_detect)
importFrom(stringr,str_extract)
importFrom(stringr,str_match)
importFrom(stringr,str_replace)
importFrom(stringr,str_replace_all)
importFrom(stringr,str_split)
importFrom(stringi,stri_detect_regex)
importFrom(stringi,stri_match_first_regex)
importFrom(stringi,stri_replace_all_regex)
importFrom(stringi,stri_replace_first_regex)
importFrom(stringi,stri_split_regex)
importFrom(utils,object.size)
importFrom(utils,str)
importFrom(utils,tail)
Expand Down
9 changes: 6 additions & 3 deletions R-package/R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,12 @@ cb.cv.predict <- function(save_models = FALSE) {
stop("'cb.cv.predict' callback requires 'basket' and 'bst_folds' lists in its calling frame")

N <- nrow(env$data)
pred <- ifelse(env$num_class > 1,
matrix(NA_real_, N, env$num_class),
rep(NA_real_, N))
pred <-
if (env$num_class > 1) {
matrix(NA_real_, N, env$num_class)
} else {
rep(NA_real_, N)
}

ntreelimit <- NVL(env$basket$best_ntreelimit,
env$end_iteration * env$num_parallel_tree)
Expand Down
2 changes: 1 addition & 1 deletion R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
if (is.null(feval)) {
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist,
as.list(evnames), PACKAGE = "xgboost")
msg <- str_split(msg, '(\\s+|:|\\s+)')[[1]][-1]
msg <- stri_split_regex(msg, '(\\s+|:|\\s+)')[[1]][-1]
res <- as.numeric(msg[c(FALSE,TRUE)]) # even indices are the values
names(res) <- msg[c(TRUE,FALSE)] # odds are the names
} else {
Expand Down
6 changes: 3 additions & 3 deletions R-package/R/xgb.dump.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ..
model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats), PACKAGE = "xgboost")

if (is.null(fname))
model_dump <- str_replace_all(model_dump, '\t', '')
model_dump <- stri_replace_all_regex(model_dump, '\t', '')

model_dump <- unlist(str_split(model_dump, '\n'))
model_dump <- grep('(^$|^0$)', model_dump, invert = TRUE, value = TRUE)
model_dump <- unlist(stri_split_regex(model_dump, '\n'))
model_dump <- grep('^\\s*$', model_dump, invert = TRUE, value = TRUE)

if (is.null(fname)) {
return(model_dump)
Expand Down
135 changes: 135 additions & 0 deletions R-package/R/xgb.ggplot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# ggplot backend for the xgboost plotting facilities


#' @rdname xgb.plot.importance
#' @export
xgb.ggplot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
rel_to_first = FALSE, n_clusters = c(1:10), ...) {

importance_matrix <- xgb.plot.importance(importance_matrix, top_n = top_n, measure = measure,
rel_to_first = rel_to_first, plot = FALSE, ...)
if (!requireNamespace("ggplot2", quietly = TRUE)) {
stop("ggplot2 package is required", call. = FALSE)
}
if (!requireNamespace("Ckmeans.1d.dp", quietly = TRUE)) {
stop("Ckmeans.1d.dp package is required", call. = FALSE)
}

clusters <- suppressWarnings(
Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix$Importance, n_clusters)
)
importance_matrix[, Cluster := as.character(clusters$cluster)]

plot <-
ggplot2::ggplot(importance_matrix,
ggplot2::aes(x = factor(Feature, levels = rev(Feature)), y = Importance, width = 0.05),
environment = environment()) +
ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity", position = "identity") +
ggplot2::coord_flip() +
ggplot2::xlab("Features") +
ggplot2::ggtitle("Feature importance") +
ggplot2::theme(plot.title = ggplot2::element_text(lineheight = .9, face = "bold"),
panel.grid.major.y = ggplot2::element_blank())
return(plot)
}


#' @rdname xgb.plot.deepness
#' @export
xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight")) {

if (!requireNamespace("ggplot2", quietly = TRUE))
stop("ggplot2 package is required for plotting the graph deepness.", call. = FALSE)

which <- match.arg(which)

dt_depths <- xgb.plot.deepness(model = model, plot = FALSE)
dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
setkey(dt_summaries, 'Depth')

if (which == "2x1") {
p1 <-
ggplot2::ggplot(dt_summaries) +
ggplot2::geom_bar(ggplot2::aes(x = Depth, y = N), stat = "Identity") +
ggplot2::xlab("") +
ggplot2::ylab("Number of leafs") +
ggplot2::ggtitle("Model complexity") +
ggplot2::theme(
plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
panel.grid.major.y = ggplot2::element_blank(),
axis.ticks = ggplot2::element_blank(),
axis.text.x = ggplot2::element_blank()
)

p2 <-
ggplot2::ggplot(dt_summaries) +
ggplot2::geom_bar(ggplot2::aes(x = Depth, y = Cover), stat = "Identity") +
ggplot2::xlab("Leaf depth") +
ggplot2::ylab("Weighted cover")

multiplot(p1, p2, cols = 1)
return(invisible(list(p1, p2)))

} else if (which == "max.depth") {
p <-
ggplot2::ggplot(dt_depths[, max(Depth), Tree]) +
ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1),
height = 0.15, alpha=0.4, size=3, stroke=0) +
ggplot2::xlab("tree #") +
ggplot2::ylab("Max tree leaf depth")
return(p)

} else if (which == "med.depth") {
p <-
ggplot2::ggplot(dt_depths[, median(as.numeric(Depth)), Tree]) +
ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1),
height = 0.15, alpha=0.4, size=3, stroke=0) +
ggplot2::xlab("tree #") +
ggplot2::ylab("Median tree leaf depth")
return(p)

} else if (which == "med.weight") {
p <-
ggplot2::ggplot(dt_depths[, median(abs(Weight)), Tree]) +
ggplot2::geom_point(ggplot2::aes(x = Tree, y = V1),
alpha=0.4, size=3, stroke=0) +
ggplot2::xlab("tree #") +
ggplot2::ylab("Median absolute leaf weight")
return(p)
}
}

# Plot multiple ggplot graph aligned by rows and columns.
# ... the plots
# cols number of columns
# internal utility function
multiplot <- function(..., cols = 1) {
plots <- list(...)
num_plots = length(plots)

layout <- matrix(seq(1, cols * ceiling(num_plots / cols)),
ncol = cols, nrow = ceiling(num_plots / cols))

if (num_plots == 1) {
print(plots[[1]])
} else {
grid::grid.newpage()
grid::pushViewport(grid::viewport(layout = grid::grid.layout(nrow(layout), ncol(layout))))
for (i in 1:num_plots) {
# Get the i,j matrix positions of the regions that contain this subplot
matchidx <- as.data.table(which(layout == i, arr.ind = TRUE))

print(
plots[[i]], vp = grid::viewport(
layout.pos.row = matchidx$row,
layout.pos.col = matchidx$col
)
)
}
}
}

globalVariables(c(
"Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
"element_blank", "element_text"
))
10 changes: 5 additions & 5 deletions R-package/R/xgb.model.dt.tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
text <- xgb.dump(model = model, with_stats = T)
}

position <- which(!is.na(str_match(text, "booster")))
position <- which(!is.na(stri_match_first_regex(text, "booster")))

add.tree.id <- function(x, i) paste(i, x, sep = "-")

Expand All @@ -82,16 +82,16 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
n_first_tree <- min(max(td$Tree), n_first_tree)
td <- td[Tree <= n_first_tree & !grepl('^booster', t)]

td[, Node := str_match(t, "(\\d+):")[,2] %>% as.numeric ]
td[, Node := stri_match_first_regex(t, "(\\d+):")[,2] %>% as.numeric ]
td[, ID := add.tree.id(Node, Tree)]
td[, isLeaf := !is.na(str_match(t, "leaf"))]
td[, isLeaf := !is.na(stri_match_first_regex(t, "leaf"))]

# parse branch lines
td[isLeaf==FALSE, c("Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover") := {
rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
# skip some indices with spurious capture groups from anynumber_regex
xtr <- str_match(t, rx)[, c(2,3,5,6,7,8,10)]
xtr <- stri_match_first_regex(t, rx)[, c(2,3,5,6,7,8,10)]
xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
lapply(1:ncol(xtr), function(i) xtr[,i])
}]
Expand All @@ -102,7 +102,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
# parse leaf lines
td[isLeaf==TRUE, c("Feature", "Quality", "Cover") := {
rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
xtr <- str_match(t, rx)[, c(2,4)]
xtr <- stri_match_first_regex(t, rx)[, c(2,4)]
c("Leaf", lapply(1:ncol(xtr), function(i) xtr[,i]))
}]

Expand Down
Loading

0 comments on commit d5c1433

Please sign in to comment.