Commit ba9f993

Add IJMPR replication (#1436)
1 parent 83e2f9f commit ba9f993

5 files changed, +4292 -0 lines changed

README.md

Lines changed: 4 additions & 0 deletions
@@ -161,6 +161,10 @@ Imke Mayer, Erik Sverdrup, Tobias Gauss, Jean-Denis Moyer, Stefan Wager and Juli
 [<a href="https://projecteuclid.org/euclid.aoas/1600454872">paper</a>,
 <a href="https://arxiv.org/pdf/1910.10624.pdf">arxiv</a>]
 
+Erik Sverdrup, Maria Petukhova, and Stefan Wager.
+<b>Estimating Treatment Effect Heterogeneity in Psychiatry: A Review and Tutorial with Causal Forests.</b> 2024.
+[<a href="https://arxiv.org/abs/2409.01578">arxiv</a>]
+
 Stefan Wager and Susan Athey.
 <b>Estimation and Inference of Heterogeneous Treatment Effects using Random Forests.</b>
 <i>Journal of the American Statistical Association</i>, 113(523), 2018.

experiments/README.md

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,8 @@ This directory contains replication code for
 
 * Mayer, Sverdrup, Gauss, Moyer, Wager, and Josse (2020): This is available at https://github.com/imkemayer/causal-inference-missing
 
+* Sverdrup, Petukhova, and Wager (2024): `ijmpr`
+
 * Wager and Athey (2018): This paper is not based on GRF, but on the deprecated `causalForest`. For replication code see https://github.com/swager/causalForest
 
 * Yadlowsky, Fleming, Shah, Brunskill, and Wager (2021): The method is available in the GRF function `rank_average_treatment_effect`. For replication code see https://github.com/som-shahlab/RATE-experiments
@@ -46,6 +48,10 @@ Imke Mayer, Erik Sverdrup, Tobias Gauss, Jean-Denis Moyer, Stefan Wager and Juli
 [<a href="https://projecteuclid.org/euclid.aoas/1600454872">paper</a>,
 <a href="https://arxiv.org/pdf/1910.10624.pdf">arxiv</a>]
 
+Erik Sverdrup, Maria Petukhova, and Stefan Wager.
+<b>Estimating Treatment Effect Heterogeneity in Psychiatry: A Review and Tutorial with Causal Forests.</b> 2024.
+[<a href="https://arxiv.org/abs/2409.01578">arxiv</a>]
+
 Stefan Wager and Susan Athey.
 <b>Estimation and Inference of Heterogeneous Treatment Effects using Random Forests.</b>
 <i>Journal of the American Statistical Association</i>, 113(523), 2018.

experiments/ijmpr/README.md

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+_This folder has replication files for the paper "Estimating Treatment Effect Heterogeneity in Psychiatry: A Review and Tutorial with Causal Forests" by Sverdrup, Petukhova, and Wager._
+
+The file `analysis.R` contains example code to run the type of analysis described in the paper using synthetic data. This script relies on the packages `"grf", "maq", "ggplot2"`.
+
+The file `synthetic_data.csv` was generated using the following code and is not intended to bear resemblance to the Army STARRS-LS data used in the paper.
+
+```R
+n = 4000
+X = cbind(round(matrix(rnorm(n * 5), n, 5), 2), matrix(rbinom(n * 4, 1, 0.5), n, 4))
+W = rbinom(n, 1, 1 / (1 + exp(-X[, 1] - X[, 2])))
+Y = 1 - rbinom(n, 1, 1 / (1 + exp((pmax(2 * X[, 1], 0) * W + 1))))
+colnames(X) = make.names(1:ncol(X))
+write.csv(cbind(outcome = Y, treatment = W, X), "synthetic_data.csv", row.names = FALSE)
+```
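
Because the outcome is drawn from a logistic model, the treatment effect implied by this synthetic data can be written down directly. The sketch below is not part of the commit. It assumes only the generating code above, and `true.cate` is a hypothetical helper name introduced for illustration. It uses the identity `1 - 1 / (1 + exp(z)) = plogis(z)` and takes the difference in outcome probability between `W = 1` and `W = 0`.

```R
# Illustrative only: closed-form CATE implied by the generating code above.
# P[Y = 1 | X, W] = 1 - 1 / (1 + exp(z)) = plogis(z), with z = pmax(2 * X1, 0) * W + 1.
# `true.cate` is a hypothetical helper name, not something defined in the repository.
true.cate = function(x1) {
  plogis(pmax(2 * x1, 0) + 1) - plogis(1)
}

# Units with X1 <= 0 have no treatment effect; the effect grows with X1 above zero.
x1.grid = c(-1, 0, 0.5, 1, 2)
print(data.frame(X1 = x1.grid, cate = round(true.cate(x1.grid), 3)))
```

Under this data-generating process only the first covariate drives heterogeneity, so a well-fit CATE model would be expected to rank units largely by `X1`.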

experiments/ijmpr/analysis.R

Lines changed: 267 additions & 0 deletions
@@ -0,0 +1,267 @@
+rm(list = ls())
+set.seed(42)
+
+library(grf)
+library(maq) # For Qini curves.
+library(ggplot2)
+
+# Read in data and specify outcome Y, treatment W, and (numeric) matrix of covariates X.
+data = read.csv("https://raw.githubusercontent.com/grf-labs/grf/master/experiments/ijmpr/synthetic_data.csv")
+Y = data$outcome
+W = data$treatment
+X = data[, -c(1, 2)]
+
+
+# This script assumes the covariates X have named columns.
+# If not provided, we make up some default names.
+if (is.null(colnames(X))) colnames(X) = make.names(1:ncol(X))
+
+
+# *** Estimating an average treatment effect (ATE) ***
+
+# A simple difference in means estimate ignoring non-random assignment.
+summary(lm(Y ~ W))
+
+# A doubly robust ATE estimate (forest-based AIPW).
+cf.full = causal_forest(X, Y, W)
+average_treatment_effect(cf.full)
+
+# A histogram of the estimated propensity scores.
+# Overlap requires that these don't get too close to either 0 or 1.
+hist(cf.full$W.hat, xlab = "Estimated propensity scores", main = "")
+
+
# *** Estimating CATEs ***
35+
36+
# Split data into a train and test sample.
37+
train = sample(nrow(X), 0.6 * nrow(X))
38+
test = -train
39+
40+
# Fit a CATE function on training data.
41+
cate.forest = causal_forest(X[train, ], Y[train], W[train])
42+
43+
# Predict CATEs on test set.
44+
X.test = X[test, ]
45+
tau.hat.test = predict(cate.forest, X.test)$predictions
46+
47+
# A histogram of CATE estimates.
48+
hist(tau.hat.test, xlab = "Estimated CATEs", main = "")
49+
50+
# On their own, the CATE point estimates are noisy.
51+
# What we often care about is whether they capture meaningful heterogeneity.
52+
53+
# Here we construct groups according to which quartile of the predicted CATEs the unit belongs.
54+
# Then, we calculate ATEs in each of these groups and see if they differ.
55+
num.groups = 4 # 4 for quartiles, 5 for quintiles, etc.
56+
quartile = cut(tau.hat.test,
57+
quantile(tau.hat.test, seq(0, 1, by = 1 / num.groups)),
58+
labels = 1:num.groups,
59+
include.lowest = TRUE)
60+
# Create a list of test set samples by CATE quartile.
61+
samples.by.quartile = split(seq_along(quartile), quartile)
62+
63+
# Look at ATEs in each of these quartiles. To calculate these we fit a separate evaluation forest.
64+
eval.forest = causal_forest(X.test, Y[test], W[test])
65+
66+
# Calculate doubly robust ATEs for each group.
67+
ate.by.quartile = lapply(samples.by.quartile, function(samples) {
68+
average_treatment_effect(eval.forest, subset = samples)
69+
})
70+
71+
# Plot group ATEs along with 95% confidence bars.
72+
df.plot.ate = data.frame(
73+
matrix(unlist(ate.by.quartile), num.groups, byrow = TRUE, dimnames = list(NULL, c("estimate","std.err"))),
74+
group = 1:num.groups
75+
)
76+
77+
ggplot(df.plot.ate, aes(x = group, y = estimate)) +
78+
geom_point() +
79+
geom_errorbar(aes(ymin = estimate - 1.96 * std.err, ymax = estimate + 1.96 * std.err, width = 0.2)) +
80+
xlab("Estimated CATE quantile") +
81+
ylab("Average treatment effect")
82+
83+
+# *** Evaluate heterogeneity via TOC/AUTOC ***
+
+# In the previous section we split the data into quartiles.
+# The TOC is essentially a continuous version of this exercise.
+
+# Use eval.forest to form a doubly robust estimate of TOC/AUTOC.
+rate.cate = rank_average_treatment_effect(
+  eval.forest,
+  tau.hat.test,
+  q = seq(0.05, 1, length.out = 100)
+)
+# Plot the TOC.
+plot(rate.cate)
+
+# An estimate and standard error of AUTOC.
+print(rate.cate)
+
+# Get a 2-sided p-value Pr(>|t|) for RATE = 0 using a t-value.
+2 * pnorm(-abs(rate.cate$estimate / rate.cate$std.err))
+
+# [For alternatives to perform these tests without performing train/test splits,
+# including relying on one-sided tests, see the following vignette for more details
+# https://grf-labs.github.io/grf/articles/rate_cv.html]
+
+
+# *** Policy evaluation with Qini curves ***
+
+# We can use the `maq` package for this exercise. This package is more general
+# and accepts CATE estimates from multiple treatment arms along with costs that
+# denominate what we spend by assigning a unit a treatment. In this application
+# we can simply treat the number of units we are considering deploying as the cost.
+
+# Form a doubly robust estimate of a CATE-based Qini curve (using eval.forest).
+num.units = nrow(X)
+qini = maq(tau.hat.test,
+           num.units,
+           get_scores(eval.forest) * num.units,
+           R = 200)
+
+# Form a baseline Qini curve that assigns treatment uniformly.
+qini.baseline = maq(tau.hat.test,
+                    num.units,
+                    get_scores(eval.forest) * num.units,
+                    R = 200,
+                    target.with.covariates = FALSE)
+
+# Plot the Qini curve along with 95% confidence lines.
+plot(qini, ylab = "PTSD cases prevented", xlab = "Units held back from deployment", xlim = c(0, num.units))
+plot(qini.baseline, add = TRUE, ci.args = NULL)
+
+# Get estimates from the curve, at for example 500 deployed units.
+average_gain(qini, 500)
+
+# Compare the benefit of targeting the 500 units predicted to benefit the most with the baseline.
+difference_gain(qini, qini.baseline, 500)
+
+
+# [The paper shows Qini curves embellished with ggplot. We could have retrieved
+# the data underlying the curves and customized our plots further.
+# For more details we refer to https://github.com/grf-labs/maq]
+
+
+# *** Describing the fit CATE function ***
+
+# Our `cate.forest` has given us some estimated function \tau(x).
+# Let's have a closer look at how this function stratifies our sample in terms of "covariate" profiles.
+# One way to do so is to look at histograms of our covariates by, for example, low / high CATE predictions.
+
+# First, we'll use a simple heuristic to narrow down the number of predictors to look closer at.
+# Here we use the variable importance metric of the fit CATE function to select 4 predictors to look closer at.
+varimp.cate = variable_importance(cate.forest)
+ranked.variables = order(varimp.cate, decreasing = TRUE)
+top.varnames = colnames(X)[ranked.variables[1:4]]
+print(top.varnames)
+
+# Select the test set samples predicted to have low/high CATEs.
+# [We could also have used the full sample for this exercise.]
+low = samples.by.quartile[[1]]
+high = samples.by.quartile[[num.groups]]
+
+# Make some long format data frames for ggplot.
+df.lo = data.frame(
+  covariate.value = unlist(as.vector(X.test[low, top.varnames])),
+  covariate.name = rep(top.varnames, each = length(low)),
+  cate.estimates = "Low"
+)
+df.hi = data.frame(
+  covariate.value = unlist(as.vector(X.test[high, top.varnames])),
+  covariate.name = rep(top.varnames, each = length(high)),
+  cate.estimates = "High"
+)
+df.plot.hist = rbind(df.lo, df.hi)
+
+# Plot overlaid histograms of the selected covariates by low/high classification.
+ggplot(df.plot.hist, aes(x = covariate.value, fill = cate.estimates)) +
+  geom_histogram(alpha = 0.7, position = "identity") +
+  facet_wrap(~ covariate.name, scales = "free", ncol = 2)
+
+
+# *** Best linear projections (BLP) ***
+
+# Select some potential effect modifier(s) we are interested in.
+blp.vars = c("X1", "X2", "X3")
+
+# Estimate the best linear projection on our variables.
+best_linear_projection(cf.full, X[, blp.vars])
+
+
+# *** Risk vs CATE-based targeting ***
+
+# In our application it was reasonable to hypothesize that soldiers with high
+# "risk" of developing PTSD also have a high treatment effect (i.e. low resilience).
+
+# Train a risk model on the training set. First, select units with high combat stress.
+train.hi = train[W[train] == 0]
+
+# Use a regression forest to estimate P[develops PTSD | X, high combat stress]
+# (We recorded Y = 1 if healthy and so 1 - Y is 1 if the outcome is PTSD)
+rf.risk = regression_forest(X[train.hi, ], 1 - Y[train.hi])
+risk.hat.test = predict(rf.risk, X.test)$predictions
+
+# Compare risk vs CATE-based targeting using the AUTOC.
+rate.risk = rank_average_treatment_effect(
+  eval.forest,
+  cbind(tau.hat.test, risk.hat.test)
+)
+plot(rate.risk)
+print(rate.risk)
+
+# Construct a 95% confidence interval for the AUTOCs as well as for AUTOC(cate.hat) - AUTOC(risk.hat).
+rate.risk$estimate + data.frame(lower = -1.96 * rate.risk$std.err,
+                                upper = 1.96 * rate.risk$std.err,
+                                row.names = rate.risk$target)
+
+
+# *** Appendix: Evaluate CATE models via the AUTOC ***
+
+# Causal forest is a two-step algorithm that first accounts for confounding and baseline effects
+# via the propensity score e(x) and a conditional mean model m(x), then in the second step estimates
+# treatment effect heterogeneity. In some settings we may want to try using different covariates (or
+# possibly models) for e(x), m(x), and CATE predictions.
+
+# Estimate m(x) = E[Y | X = x] using a regression forest.
+Y.forest = regression_forest(X[train, ], Y[train], num.trees = 500)
+Y.hat = predict(Y.forest)$predictions
+
+# Estimate e(x) = E[W | X = x] using a regression forest.
+W.forest = regression_forest(X[train, ], W[train], num.trees = 500)
+W.hat = predict(W.forest)$predictions
+
+# Select the covariates X with, for example, m(x) variable importance in the top 25%.
+varimp.Y = variable_importance(Y.forest)
+selected.vars = which(varimp.Y >= quantile(varimp.Y, 0.75))
+print(colnames(X)[selected.vars])
+
+if (length(selected.vars) <= 1) stop("You should really try and use more than just one predictor variable with forests.")
+
+# Try and fit a CATE model using this smaller set of potential heterogeneity predictors.
+X.subset = X[, selected.vars]
+cate.forest.restricted = causal_forest(X.subset[train, ], Y[train], W[train],
+                                       Y.hat = Y.hat, W.hat = W.hat)
+# Predict CATEs on test set.
+tau.hat.test.restricted = predict(cate.forest.restricted, X.test[, selected.vars])$predictions
+
+# Compare CATE models with AUTOC.
+rate.cate.compare = rank_average_treatment_effect(
+  eval.forest,
+  cbind(tau.hat.test, tau.hat.test.restricted)
+)
+# Get an estimate of the AUTOCs, as well as difference in AUTOC.
+print(rate.cate.compare)
+
+# Get a p-value for the AUTOCs and difference in AUTOCs.
+data.frame(
+  p.value = 2 * pnorm(-abs(rate.cate.compare$estimate / rate.cate.compare$std.err)),
+  target = rate.cate.compare$target
+)
+
+# Or equivalently, we could construct a 2-sided confidence interval.
+rate.cate.compare$estimate + data.frame(lower = -1.96 * rate.cate.compare$std.err,
+                                        upper = 1.96 * rate.cate.compare$std.err,
+                                        row.names = rate.cate.compare$target)
+
+# [In this synthetic example the restricted CATE model does not do much better.]
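
The final steps of the script report the same normal-approximation test in two forms, a p-value and a confidence interval. The base-R sketch below is not part of the commit. The estimates and standard errors are made-up numbers rather than output from the script, and `z.test.summary` is a hypothetical helper name. It illustrates that a two-sided p-value below 0.05 corresponds to a 95% interval that excludes zero when the interval uses `qnorm(0.975)`, which the script rounds to 1.96.

```R
# Illustrative only: two-sided p-value and matching confidence interval
# from an estimate and its standard error, under a normal approximation.
# `z.test.summary` is a hypothetical helper, not a grf or maq function.
z.test.summary = function(estimate, std.err, level = 0.95) {
  z = qnorm(1 - (1 - level) / 2)
  data.frame(
    estimate = estimate,
    p.value = 2 * pnorm(-abs(estimate / std.err)),
    lower = estimate - z * std.err,
    upper = estimate + z * std.err
  )
}

# Made-up inputs for illustration only.
res = z.test.summary(estimate = c(0.10, 0.02), std.err = c(0.03, 0.04))
res$excludes.zero = res$lower > 0 | res$upper < 0
print(res)
```

In the script itself, `estimate` and `std.err` would come from the object returned by `rank_average_treatment_effect`, as in the lines above.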
