forked from MCKnaus/dmlmt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Example.R
92 lines (71 loc) · 2.96 KB
/
Example.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# ---- Setup: install and attach dmlmt, then load the example data ----

# Install the current version of dmlmt from GitHub (via devtools) and attach it
library(devtools)
install_github(repo = "MCKnaus/dmlmt")
library(dmlmt)

# Load the `pension` data set after attaching hdm
library(hdm)
data(pension)

# Y holds pension$tw, D holds pension$p401 (the treatment vector used below)
Y <- pension$tw
D <- pension$p401

# Covariate matrix without intercept, main effects only (toy example)
X <- model.matrix(
  ~ -1 + i2 + i3 + i4 + i5 + i6 + i7 + a2 + a3 + a4 + a5 +
    fsize + hs + smcol + col + marr + twoearn + db + pira + hown,
  data = pension
)

## If you have some time, also consider all pairwise interactions
# X <- model.matrix(
#   ~ -1 + (i2 + i3 + i4 + i5 + i6 + i7 + a2 + a3 + a4 + a5 +
#     fsize + hs + smcol + col + marr + twoearn + db + pira + hown)^2,
#   data = pension
# )
###########################
#### Standard analysis ####
###########################

# Fix the RNG state so that the random control-group split below produces the
# same D_mult on every run (the seed value itself is arbitrary).
# NOTE(review): this presumably also stabilises any random sample splitting
# done inside dmlmt() -- confirm against the package documentation.
set.seed(1234)

# Binary treatment, Post-Lasso for nuisance parameters (might be slow)
stand_pl_bin <- dmlmt(X, D, Y)

# Binary treatment, Lasso for nuisance parameters (faster)
stand_l_bin <- dmlmt(X, D, Y, pl = FALSE)

# Create multiple treatments by splitting the control group at random.
# Multiplying by (1 - D) zeroes the uniform draw for every unit with D == 1,
# so only units with D == 0 can exceed 0.5 and be recoded to the value 2.
D_mult <- D
split_draw <- runif(length(D)) * (1 - D)
D_mult[split_draw > 0.5] <- 2
table(D_mult)

# Run the same two analyses with multiple treatments
stand_pl_mult <- dmlmt(X, D_mult, Y)
stand_l_mult <- dmlmt(X, D_mult, Y, pl = FALSE)
#################################
###    Extended Analysis      ###
#################################

# Alternative rules for selecting the penalty term; w = TRUE additionally
# requests the implied weights
se_rules <- c(-1, -.5, .5, 1)

# Binary treatment
ext_pl_bin <- dmlmt(X, D, Y, se_rule = se_rules, w = TRUE)

# Example how to plot the results: first row of the first (coefficients) and
# second (standard errors) element of ext_pl_bin$SE_rule, one point per rule
library(ggplot2)
rule_labels <- colnames(ext_pl_bin$SE_rule[[1]])
df <- data.frame(
  # Explicit levels keep the rules in their original column order on the axis
  SE_rule = factor(rule_labels, levels = rule_labels),
  coef = ext_pl_bin$SE_rule[[1]][1, ],
  se = ext_pl_bin$SE_rule[[2]][1, ]
)
# The plot object is only stored in `j` here; it is not printed
j <- ggplot(
  df,
  aes(x = SE_rule, y = coef, ymin = coef - se, ymax = coef + se)
) +
  geom_errorbar() +
  geom_point()

# Example how to check balancing with the package of your choice, e.g. cobalt
library(cobalt)
balance <- bal.tab(
  as.data.frame(X),
  treat = D,
  weights = ext_pl_bin$weights,
  method = "weighting",
  s.d.denom = "pooled",
  disp.v.ratio = TRUE,
  disp.ks = TRUE,
  un = TRUE
)
# Result is stored in a variable named `plot` (same name as base::plot)
plot <- love.plot(
  balance,
  abs = TRUE,
  line = TRUE,
  var.order = "unadjusted"
)

# Multiple treatments
ext_pl_mult <- dmlmt(X, D_mult, Y, se_rule = se_rules, w = TRUE)
#################################
###    Generic ML input       ###
#################################

# Example with Random Forest instead of Lasso
library(grf)

# Initialize nuisance matrices: one row per observation, one column per
# distinct treatment value
values <- sort(unique(D_mult))
ps_mat <- t_mat <- y_mat <- matrix(NA, length(Y), length(values))

# Get nuisance parameter predictions, one treatment value at a time.
# seq_along() avoids the 1:0 trap of 1:length() when `values` is empty.
for (tr in seq_along(values)) {
  # Indicator column: 1 if the unit has treatment values[tr], 0 otherwise
  t_mat[, tr] <- as.numeric(D_mult == values[tr])

  # Treatment probability: forest of the indicator on X, predicted at X.
  # NOTE(review): passing X to predict() gives in-sample rather than
  # out-of-bag predictions -- confirm this is intended.
  rf_p <- regression_forest(X, t_mat[, tr])
  ps_mat[, tr] <- predict(rf_p, X)$predictions

  # Outcome model: trained only on units with this treatment value, then
  # predicted for all rows. drop = FALSE keeps X a matrix even if only a
  # single unit falls into this treatment arm.
  in_arm <- t_mat[, tr] == 1
  rf_y <- regression_forest(X[in_arm, , drop = FALSE], Y[in_arm])
  y_mat[, tr] <- predict(rf_y, X)$predictions
}

# Calculate generalized p-score and enforce common support
rf_gps <- gps_cs(ps_mat, t_mat)

# Potential outcomes
rf_PO <- PO_dmlmt(t_mat, Y, y_mat, rf_gps$p, cs_i = rf_gps$cs)

# ATE
rf_ATE <- TE_dmlmt(rf_PO$mu, rf_gps$cs)