Skip to content

Commit

Permalink
Merge pull request imbs-hl#391 from imbs-hl/issue390
Browse files Browse the repository at this point in the history
Fix IJ se estimation with probability for single obs (imbs-hl#390)
  • Loading branch information
mnwright authored Mar 8, 2019
2 parents 7b894bc + f06f1f9 commit f57c5d9
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 3 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.11.2
Date: 2019-03-07
Date: 2019-03-08
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
11 changes: 10 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,10 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
result$predictions <- array(result$predictions, dim = c(1, length(result$predictions)))
}
} else {
## TODO: Better solution for this?
if (is.list(result$predictions) & length(result$predictions) >= 1 & is.numeric(result$predictions[[1]])) {
# Fix for single test observation
result$predictions <- list(result$predictions)
}
result$predictions <- aperm(array(unlist(result$predictions),
dim = rev(c(length(result$predictions),
length(result$predictions[[1]]),
Expand Down Expand Up @@ -441,6 +444,12 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
## Set colnames and sort by levels
colnames(result$predictions) <- forest$levels[forest$class.values]
result$predictions <- result$predictions[, forest$levels, drop = FALSE]

if (!is.matrix(result$se)) {
result$se <- matrix(result$se, ncol = length(forest$levels))
}
colnames(result$se) <- forest$levels[forest$class.values]
result$se <- result$se[, forest$levels, drop = FALSE]
}
}

Expand Down
26 changes: 26 additions & 0 deletions tests/testthat/test_jackknife.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,29 @@ test_that("No error for se estimation for many observations", {
rf <- ranger(y ~ x, dat, num.trees = 2, keep.inbag = TRUE)
expect_silent(predict(rf, dat, type = "se", se.method = "infjack"))
})

test_that("Standard error prediction working for single observation, regression", {
test <- iris[1, , drop = FALSE]
train <- iris[-1, ]

rf <- ranger(Petal.Length ~ ., train, num.trees = 5, keep.inbag = TRUE)

# Jackknife
pred <- predict(rf, test, type = "se", se.method = "jack")
expect_length(pred$se, 1)

# IJ
pred <- expect_warning(predict(rf, test, type = "se", se.method = "infjack"))
expect_length(pred$se, 1)
})

test_that("Standard error prediction working for single observation, probability", {
test <- iris[134, , drop = FALSE]
train <- iris[-134, ]

rf <- ranger(Species ~ ., train, num.trees = 5, keep.inbag = TRUE, probability = TRUE)

# IJ
pred <- expect_warning(predict(rf, test, type = "se", se.method = "infjack"))
expect_equal(dim(pred$se), c(1, 3))
})
2 changes: 1 addition & 1 deletion tests/testthat/test_maxstat.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ test_that("maxstat impurity importance is positive", {
})

test_that("maxstat corrected impurity importance is positive (on average)", {
rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5,
rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 50,
splitrule = "maxstat", importance = "impurity_corrected")
expect_gt(mean(rf$variable.importance), 0)

Expand Down

0 comments on commit f57c5d9

Please sign in to comment.