Skip to content

Commit

Permalink
Update CSF survival probability relabeling (grf-labs#1080)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikcs authored Dec 6, 2021
1 parent 807815c commit 02d9012
Show file tree
Hide file tree
Showing 5 changed files with 758 additions and 762 deletions.
2 changes: 2 additions & 0 deletions experiments/csf/estimators.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ IPCW = function(data, data.test) {
sample.weights = 1 / C.Y.hat
subset = data$D == 1
cf = causal_forest(data$X[subset, ], data$Y[subset], data$W[subset], sample.weights = sample.weights[subset])

subset = data$D == 1 | data$Y > data$y0
horizonC.index = findInterval(data$y0, sf.censor$failure.times)
if (horizonC.index != 0) {
C.Y.hat[data$Y > data$y0] = C.hat[data$Y > data$y0, horizonC.index]
Expand Down
13 changes: 6 additions & 7 deletions r-package/grf/R/causal_survival_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ causal_survival_forest <- function(X, Y, W, D,
if (length(Y.grid) <= 2) {
stop("The number of distinct event times should be more than 2.")
}
if (horizon < min(Y.grid)) {
stop("`horizon` cannot be before the first event.")
}
if (nrow(X) > 5000 && length(Y.grid) / nrow(X) > 0.1) {
warning(paste0("The number of events are more than 10% of the sample size. ",
"To reduce the computational burden of fitting survival and ",
Expand Down Expand Up @@ -266,13 +269,9 @@ causal_survival_forest <- function(X, Y, W, D,
sf.censor <- do.call(survival_forest, c(list(X = cbind(X, W), Y = Y, D = 1 - D), args.nuisance))
C.hat <- predict(sf.censor, failure.times = Y.grid)$predictions
if (target == "survival.probability") {
# P(Ci > min(Yi, horizon) | Xi, Wi)
horizonC.index <- findInterval(horizon, Y.grid)
if (horizonC.index == 0) {
C.hat[] <- 1
} else {
C.hat[, horizonC.index:ncol(C.hat)] <- C.hat[, horizonC.index]
}
# Evaluate psi up to horizon
D[Y > horizon] <- 1
Y[Y > horizon] <- horizon
}

Y.index <- findInterval(Y, Y.grid) # (invariance: Y.index > 0)
Expand Down
Loading

0 comments on commit 02d9012

Please sign in to comment.