Skip to content

Commit

Permalink
Add a survival forest test (grf-labs#1059)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikcs authored Nov 9, 2021
1 parent c67516b commit e36d796
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions r-package/grf/tests/testthat/test_survival_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,23 @@ test_that("survival_forest works as expected with missing values", {
expect_equal(mse.oob.diff, 0, tolerance = 0.001)
expect_equal(mse.diff, 0, tolerance = 0.001)
})

test_that("survival forest with complete data is ~equal to regression forest", {
n <- 500
p <- 5
X <- matrix(runif(n * p), n, p)
Y.max <- 1
failure.time <- pmin(rexp(n) * X[, 1], Y.max)
censor.time <- 1e6 * runif(n)
Y <- pmin(failure.time, censor.time)
D <- as.integer(failure.time <= censor.time)

sf <- survival_forest(X, Y, D, num.trees = 500)
pp.sf <- predict(sf)
Y.hat.sf <- expected_survival(pp.sf$predictions, pp.sf$failure.times) #integral of the survival function

rf <- regression_forest(X, Y, num.trees = 500)
pp.rf <- predict(rf)$predictions

expect_equal(Y.hat.sf, pp.rf, tolerance = 0.075)
})

0 comments on commit e36d796

Please sign in to comment.