|
| 1 | +rm(list = ls()) |
| 2 | +set.seed(42) |
| 3 | + |
| 4 | +library(grf) |
| 5 | +library(maq) # For Qini curves. |
| 6 | +library(ggplot2) |
| 7 | + |
| 8 | +# Read in data and specify outcome Y, treatment W, and covariates X (every column except the first two, which are assumed to hold outcome and treatment; drop = FALSE keeps X two-dimensional even with a single covariate). |
| 9 | +data = read.csv("https://raw.githubusercontent.com/grf-labs/grf/master/experiments/ijmpr/synthetic_data.csv") |
| 10 | +Y = data$outcome |
| 11 | +W = data$treatment |
| 12 | +X = data[, -c(1, 2), drop = FALSE] |
| 13 | + |
| 14 | + |
| 15 | +# This script assumes the covariates X have named columns. |
| 16 | +# If not provided, we make up some default names (seq_len() is safe even when X has zero columns). |
| 17 | +if (is.null(colnames(X))) colnames(X) = make.names(seq_len(ncol(X))) |
| 18 | + |
| 19 | + |
| 20 | +# *** Estimating an average treatment effect (ATE) *** |
| 21 | + |
| 22 | +# A simple difference-in-means estimate (a linear regression of Y on W), which ignores non-random assignment. |
| 23 | +summary(lm(Y ~ W)) |
| 24 | + |
| 25 | +# A doubly robust ATE estimate (forest-based AIPW), using a causal forest fit on the full sample. |
| 26 | +cf.full = causal_forest(X, Y, W) |
| 27 | +average_treatment_effect(cf.full) |
| 28 | + |
| 29 | +# A histogram of the estimated propensity scores (the W.hat stored in the fitted forest). |
| 30 | +# Overlap requires that these don't get too close to either 0 or 1. |
| 31 | +hist(cf.full$W.hat, xlab = "Estimated propensity scores", main = "") |
| 32 | + |
| 33 | + |
| 34 | +# *** Estimating CATEs *** |
| 35 | + |
| 36 | +# Split data into a train (60% of rows) and test sample; `test` holds the negated train indices, i.e. it selects every row not in `train`. |
| 37 | +train = sample(nrow(X), 0.6 * nrow(X)) |
| 38 | +test = -train |
| 39 | + |
| 40 | +# Fit a CATE function on the training data only. |
| 41 | +cate.forest = causal_forest(X[train, ], Y[train], W[train]) |
| 42 | + |
| 43 | +# Predict CATEs on the held-out test set. |
| 44 | +X.test = X[test, ] |
| 45 | +tau.hat.test = predict(cate.forest, X.test)$predictions |
| 46 | + |
| 47 | +# A histogram of the test set CATE estimates. |
| 48 | +hist(tau.hat.test, xlab = "Estimated CATEs", main = "") |
| 49 | + |
| 50 | +# On their own, the CATE point estimates are noisy. |
| 51 | +# What we often care about is whether they capture meaningful heterogeneity. |
| 52 | + |
| 53 | +# Here we construct groups according to which quantile bin (quartile by default) of the predicted CATEs each unit belongs to. |
| 54 | +# Then, we calculate ATEs in each of these groups and see if they differ. |
| 55 | +num.groups = 4 # 4 for quartiles, 5 for quintiles, etc. |
| 56 | +quartile = cut(tau.hat.test, |
| 57 | + quantile(tau.hat.test, seq(0, 1, by = 1 / num.groups)), |
| 58 | + labels = 1:num.groups, |
| 59 | + include.lowest = TRUE) |
| 60 | +# Create a list of test set sample indices (positions within the test set), one element per CATE group. |
| 61 | +samples.by.quartile = split(seq_along(quartile), quartile) |
| 62 | + |
| 63 | +# Look at ATEs in each of these groups. To calculate these we fit a separate evaluation forest on the test set. |
| 64 | +eval.forest = causal_forest(X.test, Y[test], W[test]) |
| 65 | + |
| 66 | +# Calculate doubly robust ATEs for each group by restricting the evaluation forest to that group's samples. |
| 67 | +ate.by.quartile = lapply(samples.by.quartile, function(samples) { |
| 68 | + average_treatment_effect(eval.forest, subset = samples) |
| 69 | +}) |
| 70 | + |
| 71 | +# Plot group ATEs along with 95% confidence bars (estimate +/- 1.96 standard errors); each row of the data frame is one group's (estimate, std.err) pair. |
| 72 | +df.plot.ate = data.frame( |
| 73 | + matrix(unlist(ate.by.quartile), num.groups, byrow = TRUE, dimnames = list(NULL, c("estimate","std.err"))), |
| 74 | + group = 1:num.groups |
| 75 | +) |
| 76 | + |
| 77 | +ggplot(df.plot.ate, aes(x = group, y = estimate)) + |
| 78 | + geom_point() + |
| 79 | + geom_errorbar(aes(ymin = estimate - 1.96 * std.err, ymax = estimate + 1.96 * std.err), width = 0.2) + |
| 80 | + xlab("Estimated CATE quantile") + |
| 81 | + ylab("Average treatment effect") |
| 82 | + |
| 83 | + |
| 84 | +# *** Evaluate heterogeneity via TOC/AUTOC *** |
| 85 | + |
| 86 | +# In the previous section we split the data into quartiles. |
| 87 | +# The TOC is essentially a continuous version of this exercise. |
| 88 | + |
| 89 | +# Use eval.forest to form a doubly robust estimate of TOC/AUTOC, ranking test units by their predicted CATEs. |
| 90 | +rate.cate = rank_average_treatment_effect( |
| 91 | + eval.forest, |
| 92 | + tau.hat.test, |
| 93 | + q = seq(0.05, 1, length.out = 100) |
| 94 | +) |
| 95 | +# Plot the TOC. |
| 96 | +plot(rate.cate) |
| 97 | + |
| 98 | +# An estimate and standard error of AUTOC. |
| 99 | +print(rate.cate) |
| 100 | + |
| 101 | +# Get a 2-sided p-value Pr(>|t|) for the null RATE = 0, using the t-value estimate / std.err and a normal approximation. |
| 102 | +2 * pnorm(-abs(rate.cate$estimate / rate.cate$std.err)) |
| 103 | + |
| 104 | +# [For alternative ways to perform these tests without train/test splits, |
| 105 | +# including one-sided tests, see the following vignette for more details: |
| 106 | +# https://grf-labs.github.io/grf/articles/rate_cv.html] |
| 107 | + |
| 108 | + |
| 109 | +# *** Policy evaluation with Qini curves *** |
| 110 | + |
| 111 | +# We can use the `maq` package for this exercise. This package is more general |
| 112 | +# and accepts CATE estimates from multiple treatment arms along with costs that |
| 113 | +# denominate what we spend by assigning a unit a treatment. In this application |
| 114 | +# we can simply treat the number of units we are considering deploying as the cost. |
| 115 | + |
| 116 | +# Form a doubly robust estimate of a CATE-based Qini curve (using eval.forest); the cost and the scores are both scaled by num.units = nrow(X). |
| 117 | +num.units = nrow(X) |
| 118 | +qini = maq(tau.hat.test, |
| 119 | + num.units, |
| 120 | + get_scores(eval.forest) * num.units, |
| 121 | + R = 200) |
| 122 | + |
| 123 | +# Form a baseline Qini curve that assigns treatment uniformly (same inputs, but without targeting on covariates). |
| 124 | +qini.baseline = maq(tau.hat.test, |
| 125 | + num.units, |
| 126 | + get_scores(eval.forest) * num.units, |
| 127 | + R = 200, |
| 128 | + target.with.covariates = FALSE) |
| 129 | + |
| 130 | +# Plot the Qini curve along with 95% confidence lines, and overlay the baseline curve without confidence lines. |
| 131 | +plot(qini, ylab = "PTSD cases prevented", xlab = "Units held back from deployment", xlim = c(0, num.units)) |
| 132 | +plot(qini.baseline, add = TRUE, ci.args = NULL) |
| 133 | + |
| 134 | +# Get estimates from the curve, for example at 500 units. |
| 135 | +average_gain(qini, 500) |
| 136 | + |
| 137 | +# Compare the benefit of targeting the 500 units predicted to benefit the most with the baseline. |
| 138 | +difference_gain(qini, qini.baseline, 500) |
| 139 | + |
| 140 | + |
| 141 | +# [The paper shows Qini curves embellished with ggplot. We could have retrieved |
| 142 | +# the data underlying the curves and customized our plots further. |
| 143 | +# For more details we refer to https://github.com/grf-labs/maq] |
| 144 | + |
| 145 | + |
| 146 | +# *** Describing the fit CATE function *** |
| 147 | + |
| 148 | +# Our `cate.forest` has given us some estimated function \tau(x). |
| 149 | +# Let's have a closer look at how this function stratifies our sample in terms of covariate profiles. |
| 150 | +# One way to do so is to look at histograms of our covariates by, for example, low / high CATE predictions. |
| 151 | + |
| 152 | +# First, we'll use a simple heuristic to narrow down the number of predictors to look closer at. |
| 153 | +# Here we use the variable importance metric of the fit CATE function to select the (up to) 4 most important predictors; head() avoids NA names when X has fewer than 4 columns. |
| 154 | +varimp.cate = variable_importance(cate.forest) |
| 155 | +ranked.variables = order(varimp.cate, decreasing = TRUE) |
| 156 | +top.varnames = colnames(X)[head(ranked.variables, 4)] |
| 157 | +print(top.varnames) |
| 158 | + |
| 159 | +# Select the test set samples predicted to have low/high CATEs (the first and last CATE groups). |
| 160 | +# [We could also have used the full sample for this exercise.] |
| 161 | +low = samples.by.quartile[[1]] |
| 162 | +high = samples.by.quartile[[num.groups]] |
| 163 | + |
| 164 | +# Make some long format data frames for ggplot: one row per (sample, covariate) pair, labelled by covariate name and by low/high group. |
| 165 | +df.lo = data.frame( |
| 166 | + covariate.value = unlist(as.vector(X.test[low, top.varnames])), |
| 167 | + covariate.name = rep(top.varnames, each = length(low)), |
| 168 | + cate.estimates = "Low" |
| 169 | +) |
| 170 | +df.hi = data.frame( |
| 171 | + covariate.value = unlist(as.vector(X.test[high, top.varnames])), |
| 172 | + covariate.name = rep(top.varnames, each = length(high)), |
| 173 | + cate.estimates = "High" |
| 174 | +) |
| 175 | +df.plot.hist = rbind(df.lo, df.hi) |
| 176 | + |
| 177 | +# Plot overlaid histograms of the selected covariates by low/high classification, one facet per covariate. |
| 178 | +ggplot(df.plot.hist, aes(x = covariate.value, fill = cate.estimates)) + |
| 179 | + geom_histogram(alpha = 0.7, position = "identity") + |
| 180 | + facet_wrap(~ covariate.name, scales = "free", ncol = 2) |
| 181 | + |
| 182 | + |
| 183 | +# *** Best linear projections (BLP) *** |
| 184 | + |
| 185 | +# Select some potential effect modifier(s) we are interested in (these must be column names of X). |
| 186 | +blp.vars = c("X1", "X2", "X3") |
| 187 | + |
| 188 | +# Estimate the best linear projection on our variables, using the forest fit on the full sample. |
| 189 | +best_linear_projection(cf.full, X[, blp.vars]) |
| 190 | + |
| 191 | + |
| 192 | +# *** Risk vs CATE-based targeting *** |
| 193 | + |
| 194 | +# In our application it was reasonable to hypothesize that soldiers with high |
| 195 | +# "risk" of developing PTSD also have a high treatment effect (i.e. low resilience). |
| 196 | + |
| 197 | +# Train a risk model on the training set. First, select units with high combat stress (the training units with W = 0). |
| 198 | +train.hi = train[W[train] == 0] |
| 199 | + |
| 200 | +# Use a regression forest to estimate P[develops PTSD | X, high combat stress] |
| 201 | +# (We recorded Y = 1 if healthy and so 1 - Y is 1 if the outcome is PTSD) |
| 202 | +rf.risk = regression_forest(X[train.hi, ], 1 - Y[train.hi]) |
| 203 | +risk.hat.test = predict(rf.risk, X.test)$predictions |
| 204 | + |
| 205 | +# Compare risk vs CATE-based targeting using the AUTOC, evaluating both rankings on the test set with eval.forest. |
| 206 | +rate.risk = rank_average_treatment_effect( |
| 207 | + eval.forest, |
| 208 | + cbind(tau.hat.test, risk.hat.test) |
| 209 | +) |
| 210 | +plot(rate.risk) |
| 211 | +print(rate.risk) |
| 212 | + |
| 213 | +# Construct a 95% confidence interval (estimate +/- 1.96 standard errors) for the AUTOCs as well as for AUTOC(cate.hat) - AUTOC(risk.hat). |
| 214 | +rate.risk$estimate + data.frame(lower = -1.96 * rate.risk$std.err, |
| 215 | + upper = 1.96 * rate.risk$std.err, |
| 216 | + row.names = rate.risk$target) |
| 217 | + |
| 218 | + |
| 219 | +# *** Appendix: Evaluate CATE models via the AUTOC *** |
| 220 | + |
| 221 | +# Causal forest is a two-step algorithm that first accounts for confounding and baseline effects |
| 222 | +# via the propensity score e(x) and a conditional mean model m(x), then in the second step estimates |
| 223 | +# treatment effect heterogeneity. In some settings we may want to try using different covariates (or |
| 224 | +# possibly models) for e(x), m(x), and CATE predictions. |
| 225 | + |
| 226 | +# Estimate m(x) = E[Y | X = x] using a regression forest fit on the training sample. |
| 227 | +Y.forest = regression_forest(X[train, ], Y[train], num.trees = 500) |
| 228 | +Y.hat = predict(Y.forest)$predictions |
| 229 | + |
| 230 | +# Estimate e(x) = E[W | X = x] using a regression forest fit on the training sample. |
| 231 | +W.forest = regression_forest(X[train, ], W[train], num.trees = 500) |
| 232 | +W.hat = predict(W.forest)$predictions |
| 233 | + |
| 234 | +# Select the covariates X with, for example, m(x) variable importance in the top 25% (at or above the 75th percentile). |
| 235 | +varimp.Y = variable_importance(Y.forest) |
| 236 | +selected.vars = which(varimp.Y >= quantile(varimp.Y, 0.75)) |
| 237 | +print(colnames(X)[selected.vars]) |
| 238 | + |
| 239 | +if (length(selected.vars) <= 1) stop("You should really try and use more than just one predictor variable with forests.") |
| 240 | + |
| 241 | +# Try to fit a CATE model using this smaller set of potential heterogeneity predictors, supplying the Y.hat and W.hat estimated above. |
| 242 | +X.subset = X[, selected.vars] |
| 243 | +cate.forest.restricted = causal_forest(X.subset[train, ], Y[train], W[train], |
| 244 | + Y.hat = Y.hat, W.hat = W.hat) |
| 245 | +# Predict CATEs on the test set, using the same subset of covariates. |
| 246 | +tau.hat.test.restricted = predict(cate.forest.restricted, X.test[, selected.vars])$predictions |
| 247 | + |
| 248 | +# Compare the two CATE models with AUTOC, both evaluated with eval.forest. |
| 249 | +rate.cate.compare = rank_average_treatment_effect( |
| 250 | + eval.forest, |
| 251 | + cbind(tau.hat.test, tau.hat.test.restricted) |
| 252 | +) |
| 253 | +# Get an estimate of the AUTOCs, as well as the difference in AUTOC. |
| 254 | +print(rate.cate.compare) |
| 255 | + |
| 256 | +# Get a 2-sided p-value (normal approximation) for the AUTOCs and the difference in AUTOCs. |
| 257 | +data.frame( |
| 258 | + p.value = 2 * pnorm(-abs(rate.cate.compare$estimate / rate.cate.compare$std.err)), |
| 259 | + target = rate.cate.compare$target |
| 260 | +) |
| 261 | + |
| 262 | +# Or equivalently, we could construct a 2-sided 95% confidence interval. |
| 263 | +rate.cate.compare$estimate + data.frame(lower = -1.96 * rate.cate.compare$std.err, |
| 264 | + upper = 1.96 * rate.cate.compare$std.err, |
| 265 | + row.names = rate.cate.compare$target) |
| 266 | + |
| 267 | +# [In this synthetic example the restricted CATE model does not do much better.] |
0 commit comments