Skip to content

Commit

Permalink
test and fix imbs-hl#626
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed Jul 20, 2022
1 parent 6b7febd commit 6e33850
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 3 deletions.
6 changes: 3 additions & 3 deletions R/treeInfo.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,16 @@ treeInfo <- function(object, tree = 1) {
result$prediction <- forest$split.values[[tree]]
result$prediction[!result$terminal] <- NA
if (!is.null(forest$levels)) {
result$prediction <- factor(result$prediction, levels = forest$class.values, labels = forest$levels)
result$prediction <- integer.to.factor(result$prediction, labels = forest$levels)
}
} else if (forest$treetype == "Regression") {
result$prediction <- forest$split.values[[tree]]
result$prediction[!result$terminal] <- NA
} else if (forest$treetype == "Probability estimation") {
predictions <- matrix(nrow = nrow(result), ncol = length(forest$levels))
predictions <- matrix(nrow = nrow(result), ncol = length(forest$class.values))
predictions[result$terminal, ] <- do.call(rbind, forest$terminal.class.counts[[tree]])
colnames(predictions) <- forest$levels[forest$class.values]
predictions <- predictions[, forest$levels, drop = FALSE]
predictions <- predictions[, forest$levels[sort(forest$class.values)], drop = FALSE]
colnames(predictions) <- paste0("pred.", colnames(predictions))
result <- data.frame(result, predictions)
} else if (forest$treetype == "Survival") {
Expand Down
84 changes: 84 additions & 0 deletions tests/testthat/test_treeInfo.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,90 @@ test_that("Prediction for classification is factor with correct levels", {
expect_equal(levels(ti.class.formula$prediction), levels(iris$Species))
})

test_that("Prediction for classification is same as class prediction", {
dat <- iris[sample(nrow(iris)), ]
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
replace = FALSE, sample.fraction = 1)
pred_class <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
ti <- treeInfo(rf, 1)
pred_ti <- sapply(nodes, function(x) {
ti[ti$nodeID == x, "prediction"]
})
expect_equal(pred_ti, pred_class)
})

test_that("Prediction for classification is same as class prediction, new factor", {
dat <- iris[sample(nrow(iris)), ]
dat$Species <- factor(dat$Species, levels = sample(levels(dat$Species)))
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
replace = FALSE, sample.fraction = 1)
pred_class <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
ti <- treeInfo(rf, 1)
pred_ti <- sapply(nodes, function(x) {
ti[ti$nodeID == x, "prediction"]
})
expect_equal(pred_ti, pred_class)
})

test_that("Prediction for classification is same as class prediction, unused factor levels", {
dat <- iris[c(101:150, 51:100), ]
expect_warning(rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
replace = FALSE, sample.fraction = 1))
pred_class <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
ti <- treeInfo(rf, 1)
pred_ti <- sapply(nodes, function(x) {
ti[ti$nodeID == x, "prediction"]
})
expect_equal(pred_ti, pred_class)
})

test_that("Prediction for probability is same as probability prediction", {
dat <- iris[sample(nrow(iris)), ]
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
sample.fraction = 1, replace = FALSE, probability = TRUE)
ti <- treeInfo(rf)
pred_prob <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
pred_ti <- t(sapply(nodes, function(x) {
as.matrix(ti[ti$nodeID == x, 8:10])
}))
colnames(pred_ti) <- gsub("pred\\.", "", colnames(ti)[8:10])
expect_equal(pred_prob, pred_ti)
})

test_that("Prediction for probability is same as probability prediction, new factor", {
dat <- iris[sample(nrow(iris)), ]
dat$Species <- factor(dat$Species, levels = sample(levels(dat$Species)))
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
sample.fraction = 1, replace = FALSE, probability = TRUE)
ti <- treeInfo(rf)
pred_prob <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
pred_ti <- t(sapply(nodes, function(x) {
as.matrix(ti[ti$nodeID == x, 8:10])
}))
colnames(pred_ti) <- gsub("pred\\.", "", colnames(ti)[8:10])
expect_equal(pred_prob, pred_ti)
})

test_that("Prediction for probability is same as probability prediction, unused factor levels", {
dat <- iris[c(101:150, 51:100), ]
dat$Species <- factor(dat$Species, levels = sample(levels(dat$Species)))
expect_warning(rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
sample.fraction = 1, replace = FALSE, probability = TRUE))
ti <- treeInfo(rf)
pred_prob <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
pred_ti <- t(sapply(nodes, function(x) {
as.matrix(ti[ti$nodeID == x, 8:9])
}))
colnames(pred_ti) <- gsub("pred\\.", "", colnames(ti)[8:9])
expect_equal(pred_prob, pred_ti)
})

test_that("Prediction for matrix classification is integer with correct values", {
rf <- ranger(dependent.variable.name = "Species", data = data.matrix(iris),
num.trees = 5, classification = TRUE)
Expand Down

0 comments on commit 6e33850

Please sign in to comment.