Skip to content

Commit

Permalink
Use parent forest seed in auxiliary forest (grf-labs#1070)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikcs authored Nov 25, 2021
1 parent 3b5f89b commit 0dd9537
Show file tree
Hide file tree
Showing 11 changed files with 14 additions and 2 deletions.
1 change: 1 addition & 0 deletions r-package/grf/R/causal_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ causal_forest <- function(X, Y, W,

forest <- do.call.rcpp(causal_train, c(data, args))
class(forest) <- c("causal_forest", "grf")
forest[["seed"]] <- seed
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/causal_survival_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ causal_survival_forest <- function(X, Y, W, D,

forest <- do.call.rcpp(causal_survival_train, c(data, args))
class(forest) <- c("causal_survival_forest", "grf")
forest[["seed"]] <- seed
forest[["_eta"]] <- eta
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
6 changes: 4 additions & 2 deletions r-package/grf/R/get_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ get_scores.causal_forest <- function(forest,
clusters = clusters,
sample.weights = forest$sample.weights,
num.trees = num.trees.for.weights,
ci.group.size = 1)
ci.group.size = 1,
seed = forest$seed)
V.hat <- predict(variance_forest)$predictions
debiasing.weights.all <- (forest$W.orig - forest$W.hat) / V.hat
debiasing.weights <- debiasing.weights.all[subset]
Expand Down Expand Up @@ -176,7 +177,8 @@ get_scores.instrumental_forest <- function(forest,
W.hat = forest$Z.hat,
sample.weights = forest$sample.weights,
clusters = clusters,
num.trees = num.trees.for.weights)
num.trees = num.trees.for.weights,
seed = forest$seed)
compliance.score <- predict(compliance.forest)$predictions
compliance.score <- compliance.score[subset]
} else if (length(compliance.score) == length(forest$Y.orig)) {
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/instrumental_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ instrumental_forest <- function(X, Y, W, Z,

forest <- do.call.rcpp(instrumental_train, c(data, args))
class(forest) <- c("instrumental_forest", "grf")
forest[["seed"]] <- seed
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/ll_regression_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ ll_regression_forest <- function(X, Y,
}

class(forest) <- c("ll_regression_forest", "grf")
forest[["seed"]] <- seed
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/multi_arm_causal_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ multi_arm_causal_forest <- function(X, Y, W,

forest <- do.call.rcpp(multi_causal_train, c(data, args))
class(forest) <- c("multi_arm_causal_forest", "grf")
forest[["seed"]] <- seed
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/multi_regression_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ multi_regression_forest <- function(X, Y,

forest <- do.call.rcpp(multi_regression_train, c(data, args))
class(forest) <- c("multi_regression_forest", "grf")
forest[["seed"]] <- seed
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
forest[["sample.weights"]] <- sample.weights
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/probability_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ probability_forest <- function(X, Y,

forest <- do.call.rcpp(probability_train, c(data, args))
class(forest) <- c("probability_forest", "grf")
forest[["seed"]] <- seed
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
forest[["Y.relabeled"]] <- Y.relabeled
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/quantile_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ quantile_forest <- function(X, Y,

forest <- do.call.rcpp(quantile_train, c(data, args))
class(forest) <- c("quantile_forest", "grf")
forest[["seed"]] <- seed
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
forest[["quantiles.orig"]] <- quantiles
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/regression_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ regression_forest <- function(X, Y,

forest <- do.call.rcpp(regression_train, c(data, args))
class(forest) <- c("regression_forest", "grf")
forest[["seed"]] <- seed
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/survival_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ survival_forest <- function(X, Y, D,

forest <- do.call.rcpp(survival_train, c(data, args))
class(forest) <- c("survival_forest", "grf")
forest[["seed"]] <- seed
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
forest[["Y.relabeled"]] <- Y.relabeled
Expand Down

0 comments on commit 0dd9537

Please sign in to comment.