From cab5f4097e66faf86985e86d1068a2c7bd025159 Mon Sep 17 00:00:00 2001 From: Molly Offer-Westort Date: Tue, 17 Jun 2025 19:47:47 -0500 Subject: [PATCH 1/3] diagnostics with cross-fitting --- r-package/grf/vignettes/diagnostics.Rmd | 36 +++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/r-package/grf/vignettes/diagnostics.Rmd b/r-package/grf/vignettes/diagnostics.Rmd index 12066ac7b..f84a0a13a 100644 --- a/r-package/grf/vignettes/diagnostics.Rmd +++ b/r-package/grf/vignettes/diagnostics.Rmd @@ -80,6 +80,42 @@ ate.high[["estimate"]] - ate.low[["estimate"]] + c(-1, 1) * qnorm(0.975) * sqrt(ate.high[["std.err"]]^2 + ate.low[["std.err"]]^2) ``` +While this approach may give some qualitative insight into heterogeneity, the grouping is naive, because the doubly robust scores used to determine subgroups are not independent of the scores used to estimate those group ATEs (see Athey and Wager, 2019). + +To avoid this, we can use a cross-fitting approach, where the data is split into two folds, and the "high"/"low" groups are determined by the models fit on the other fold, while the ATEs are estimated using the default out of bag predictions using the [average_treatment_effect](https://grf-labs.github.io/grf/reference/average_treatment_effect.html) function. + +```{r} +folds <- sample(rep(1:2, length.out = nrow(X))) +idxA <- which(folds == 1) +idxB <- which(folds == 2) + +cfA <- causal_forest(X[idxA,], Y[idxA], W[idxA]) +cfB <- causal_forest(X[idxB,], Y[idxB], W[idxB]) + +tau.hatB <- predict(cfA, newdata = X[idxB,])$predictions +high.effectB <- tau.hatB > median(tau.hatB) +tau.hatA <- predict(cfB, newdata = X[idxA,])$predictions +high.effectA <- tau.hatA > median(tau.hatA) + +ate.highA <- average_treatment_effect(cfA, subset = high.effectA) +ate.lowA <- average_treatment_effect(cfA, subset = !high.effectB) +ate.highB <- average_treatment_effect(cfB, subset = high.effectB) +ate.lowB <- average_treatment_effect(cfB, subset = !high.effectB) + +``` + +Which gives the following 95% confidence interval for the difference in ATE + +```{r} +mean(c( + ate.highA[["estimate"]] - ate.lowA[["estimate"]], + ate.highB[["estimate"]] - ate.lowB[["estimate"]])) + + c(-1, 1) * qnorm(0.975) * sqrt( + ate.highA[["std.err"]]^2 + ate.lowA[["std.err"]]^2 + + ate.highB[["std.err"]]^2 + ate.lowB[["std.err"]]^2 + ) +``` + For another way to assess heterogeneity, see the function [rank_average_treatment_effect](https://grf-labs.github.io/grf/reference/rank_average_treatment_effect.html) and the accompanying [vignette](https://grf-labs.github.io/grf/articles/rate.html). Athey et al. (2017) suggests a bias measure to gauge how much work the propensity and outcome models have to do to get an unbiased estimate, relative to looking at a simple difference-in-means: $bias(x) = (e(x) - p) \times (p(\mu(0, x) - \mu_0) + (1 - p) (\mu(1, x) - \mu_1)$. From 81d7b40be027695c0c531e2e458074bae251dfda Mon Sep 17 00:00:00 2001 From: Molly Offer-Westort Date: Tue, 17 Jun 2025 19:47:47 -0500 Subject: [PATCH 2/3] diagnostics with cross-fitting --- r-package/grf/vignettes/diagnostics.Rmd | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r-package/grf/vignettes/diagnostics.Rmd b/r-package/grf/vignettes/diagnostics.Rmd index f84a0a13a..8e943c03b 100644 --- a/r-package/grf/vignettes/diagnostics.Rmd +++ b/r-package/grf/vignettes/diagnostics.Rmd @@ -98,7 +98,7 @@ tau.hatA <- predict(cfB, newdata = X[idxA,])$predictions high.effectA <- tau.hatA > median(tau.hatA) ate.highA <- average_treatment_effect(cfA, subset = high.effectA) -ate.lowA <- average_treatment_effect(cfA, subset = !high.effectB) +ate.lowA <- average_treatment_effect(cfA, subset = !high.effectA) ate.highB <- average_treatment_effect(cfB, subset = high.effectB) ate.lowB <- average_treatment_effect(cfB, subset = !high.effectB) @@ -113,7 +113,7 @@ mean(c( c(-1, 1) * qnorm(0.975) * sqrt( ate.highA[["std.err"]]^2 + ate.lowA[["std.err"]]^2 + ate.highB[["std.err"]]^2 + ate.lowB[["std.err"]]^2 - ) + )/2 ``` For another way to assess heterogeneity, see the function [rank_average_treatment_effect](https://grf-labs.github.io/grf/reference/rank_average_treatment_effect.html) and the accompanying [vignette](https://grf-labs.github.io/grf/articles/rate.html). From e6a0a0dca7aea5d190352f90f984489af0c2f16a Mon Sep 17 00:00:00 2001 From: Molly Offer-Westort Date: Tue, 17 Jun 2025 19:47:47 -0500 Subject: [PATCH 3/3] diagnostics with cross-fitting --- r-package/grf/vignettes/diagnostics.Rmd | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/r-package/grf/vignettes/diagnostics.Rmd b/r-package/grf/vignettes/diagnostics.Rmd index 8e943c03b..73af05493 100644 --- a/r-package/grf/vignettes/diagnostics.Rmd +++ b/r-package/grf/vignettes/diagnostics.Rmd @@ -104,16 +104,15 @@ ate.lowB <- average_treatment_effect(cfB, subset = !high.effectB) ``` -Which gives the following 95% confidence interval for the difference in ATE +Which gives us 95% confidence intervals for the difference in ATE for each fold using the same approach as above. ```{r} -mean(c( - ate.highA[["estimate"]] - ate.lowA[["estimate"]], - ate.highB[["estimate"]] - ate.lowB[["estimate"]])) + - c(-1, 1) * qnorm(0.975) * sqrt( - ate.highA[["std.err"]]^2 + ate.lowA[["std.err"]]^2 + - ate.highB[["std.err"]]^2 + ate.lowB[["std.err"]]^2 - )/2 +ate.highA[["estimate"]] - ate.lowA[["estimate"]] + + c(-1, 1) * qnorm(0.975) * sqrt(ate.highA[["std.err"]]^2 + ate.lowA[["std.err"]]^2) + +ate.highB[["estimate"]] - ate.lowB[["estimate"]] + + c(-1, 1) * qnorm(0.975) * sqrt(ate.highB[["std.err"]]^2 + ate.lowB[["std.err"]]^2) + ``` For another way to assess heterogeneity, see the function [rank_average_treatment_effect](https://grf-labs.github.io/grf/reference/rank_average_treatment_effect.html) and the accompanying [vignette](https://grf-labs.github.io/grf/articles/rate.html).