
Add lm_forest (grf-labs#1138)
erikcs authored May 20, 2022
1 parent e80ed4f commit dbb6a23
Showing 9 changed files with 697 additions and 2 deletions.
2 changes: 1 addition & 1 deletion REFERENCE.md
@@ -241,7 +241,7 @@ There are however empirical applications where there are more than one primary o

Another closely related practical application is settings with several interventions, as in medical trials with multiple treatment arms. When there are K mutually exclusive treatment choices, we can use the same algorithmic principles described above to build a forest that jointly targets heterogeneity across the K-1 different treatment contrasts. In the software package this is done with equation (20) from the GRF paper, where Wi is encoded as a vector in {0, 1}^(K-1), and ξ selects the K-1 gradient approximations of the contrasts.

The functionality described above is available in `multi_arm_causal_forest`. The intended use case for this function is a "handful" of arms or outcomes. Statistical considerations aside (such as limited overlap with many arms), our implementation has per-tree memory scaling of the form `O(number.of.nodes * M * K^2)`.
The functionality described above is available in `multi_arm_causal_forest`. The intended use case for this function is a "handful" of arms or outcomes. Statistical considerations aside (such as limited overlap with many arms), our implementation has per-tree memory scaling of the form `O(number.of.nodes * M * K^2)`. The general case where all `Wk` are real-valued is implemented in the conditionally linear model forest: `lm_forest`.
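
For orientation, here is a minimal sketch of how `lm_forest` might be called in the general case with two real-valued regressors. The data-generating process is made up for illustration and is not part of this commit's diff; the call signature and prediction indexing follow the unit tests added below.

```r
library(grf)

# Hypothetical data: K = 2 real-valued regressors with heterogeneous coefficients.
n <- 2000
p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- matrix(runif(n * 2), n, 2)
Y <- X[, 1] + W[, 1] * pmax(X[, 2], 0) + W[, 2] * X[, 3] + rnorm(n)

forest <- lm_forest(X, Y, W, num.trees = 500)

# The unit tests below index predictions as a 3-D array (sample x regressor x outcome),
# so with a single outcome this slice gives an n x K matrix of estimated coefficients h_k(x).
h.hat <- predict(forest)$predictions[, , 1]
dim(h.hat)  # 2000 x 2
```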

### Right-Censored Survival Outcomes

2 changes: 1 addition & 1 deletion azure-pipelines.yml
@@ -105,7 +105,7 @@ jobs:
sudo Rscript -e "install.packages(c('fs', 'highlight', 'httr', 'memoise', 'openssl', 'purrr', 'rmarkdown', 'whisker', 'xml2', 'yaml'))"
sudo Rscript -e "install.packages('https://cran.r-project.org/src/contrib/Archive/pkgdown/pkgdown_1.5.1.tar.gz', repos = NULL, type = 'source')"
# Install dependencies needed for vignettes.
sudo Rscript -e "install.packages(c('DiagrammeR', 'ggplot2', 'glmnet', 'policytree'))"
sudo Rscript -e "install.packages(c('DiagrammeR', 'ggplot2', 'glmnet', 'policytree', 'rdd'))"
cp ../../README.md .
cp ../../REFERENCE.md .
cp ../../DEVELOPING.md .
1 change: 1 addition & 0 deletions r-package/grf/DESCRIPTION
@@ -33,6 +33,7 @@ RoxygenNote: 7.1.2
Suggests:
DiagrammeR,
MASS,
rdd,
survival (>= 3.2-8),
testthat (>= 3.0.4)
SystemRequirements: GNU make
2 changes: 2 additions & 0 deletions r-package/grf/NAMESPACE
@@ -11,6 +11,7 @@ S3method(predict,causal_forest)
S3method(predict,causal_survival_forest)
S3method(predict,instrumental_forest)
S3method(predict,ll_regression_forest)
S3method(predict,lm_forest)
S3method(predict,multi_arm_causal_forest)
S3method(predict,multi_regression_forest)
S3method(predict,probability_forest)
@@ -39,6 +40,7 @@ export(get_scores)
export(get_tree)
export(instrumental_forest)
export(ll_regression_forest)
export(lm_forest)
export(merge_forests)
export(multi_arm_causal_forest)
export(multi_regression_forest)
365 changes: 365 additions & 0 deletions r-package/grf/R/lm_forest.R


172 changes: 172 additions & 0 deletions r-package/grf/man/lm_forest.Rd


82 changes: 82 additions & 0 deletions r-package/grf/man/predict.lm_forest.Rd


5 changes: 5 additions & 0 deletions r-package/grf/pkgdown/_pkgdown.yml
@@ -51,6 +51,11 @@ reference:
- instrumental_forest
- predict.instrumental_forest

- title: Linear model forest
contents:
- lm_forest
- predict.lm_forest

- title: Probability forest
contents:
- probability_forest
68 changes: 68 additions & 0 deletions r-package/grf/tests/testthat/test_lm_forest.R
@@ -0,0 +1,68 @@
test_that("lm_forest with single W ~ causal forest", {
# These tests are not done with an epsilon tolerance due to forests'
# discontinuous nature. Even though these two calls are in principle identical (with the same seed),
# some splits further down the tree might deviate by chance due to minor numerical differences in
# implementation, thus leading to final point predictions that can differ by more than an epsilon.
# For tests that lock in an equivalence between causal forest and its multivariate extension,
# see `MultiCausalSplittingRuleTest.cpp`.

# Binary W
n <- 1500
p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 0.5)
Y <- pmax(X[, 1], 0) * W + X[, 2] + pmin(X[, 3], 0) + rnorm(n)
wts <- sample(1:2, n, TRUE)

Y.hat <- predict(regression_forest(X, Y, num.trees = 500, sample.weights = wts))$predictions
cf <- causal_forest(X, Y, W, Y.hat = Y.hat, W.hat = 0.5, sample.weights = wts, num.trees = 500, stabilize.splits = FALSE)
lmf <- lm_forest(X, Y, W, Y.hat = Y.hat, W.hat = 0.5, sample.weights = wts, num.trees = 500)
expect_lt(mean((predict(cf)$predictions - predict(lmf)$predictions[,,])^2), 0.03)
expect_equal(mean(predict(cf)$predictions), mean(predict(lmf)$predictions[,,]), tolerance = 0.03)

# Continuous W
W <- runif(n)
Y <- pmax(X[, 1], 0) * W + X[, 2] + pmin(X[, 3], 0) + rnorm(n)
Y.hat <- predict(regression_forest(X, Y, num.trees = 500, sample.weights = wts))$predictions
cfw <- causal_forest(X, Y, W, Y.hat = Y.hat, W.hat = 0.5, sample.weights = wts, num.trees = 500, stabilize.splits = FALSE)
lmfw <- lm_forest(X, Y, W, Y.hat = Y.hat, W.hat = 0.5, sample.weights = wts, num.trees = 500)
expect_lt(mean((predict(cfw)$predictions - predict(lmfw)$predictions[,,])^2), 0.05)
expect_equal(mean(predict(cfw)$predictions), mean(predict(lmfw)$predictions[,,]), tolerance = 0.05)
})

test_that("lm_forest with dummy W = multi arm causal forest", {
n <- 500
p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + 1.5 * (W == "A") + 2.8 * (W == "B") - 4 * (W == "C") + rnorm(n)
wts <- sample(1:2, n, TRUE)

Y.hat <- predict(multi_regression_forest(X, Y, num.trees = 250))$predictions
W.hat <- predict(probability_forest(X, W, num.trees = 250))$predictions

W.matrix <- model.matrix(~ W - 1)

mcf <- multi_arm_causal_forest(X, Y, W, Y.hat = Y.hat, W.hat = W.hat, num.trees = 250, sample.weights = wts, seed = 42, stabilize.splits = FALSE)
lmf <- lm_forest(X, Y, W.matrix[, -1], Y.hat = Y.hat, W.hat = W.hat[, -1], num.trees = 250, sample.weights = wts, seed = 42)

expect_equal(unname(predict(lmf)$predictions), unname(predict(mcf)$predictions))
})

test_that("lm_forest gradient.weights option works as expected", {
n <- 250
p <- 5
K <- 2
X <- matrix(rnorm(n * p), n, p)
W <- matrix(runif(n * K), n, K)
Y <- X[, 1] - W[, 1] * pmax(X[, 2], 0) + W[, 2] + rnorm(n)

lmf <- lm_forest(X, Y, W, num.trees = 250, gradient.weights = c(0.5, 1), seed = 42)
lmf2 <- lm_forest(X, cbind(Y, Y), W, num.trees = 250, gradient.weights = c(0.5, 1), seed = 42)
expect_equal(predict(lmf)$predictions[,,], predict(lmf2)$predictions[,, 1])
expect_equal(predict(lmf)$predictions[,,], predict(lmf2)$predictions[,, 2])

lmf3 <- lm_forest(X, Y, W, num.trees = 250, seed = 42)
lmf4 <- lm_forest(X, Y, W, num.trees = 250, gradient.weights = c(0.5, 0.5), seed = 42)
expect_equal(predict(lmf3)$predictions, predict(lmf4)$predictions)
})
