Skip to content

Commit 4401d69

Browse files
Merge branch 'main' into carl/metrics-take2
2 parents de716b0 + 2515de8 commit 4401d69

File tree

11 files changed

+3517
-8
lines changed

11 files changed

+3517
-8
lines changed

econml/dml/causal_forest.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from .._cate_estimator import LinearCateEstimator
1818
from .._shap import _shap_explain_multitask_model_cate
1919
from .._ortho_learner import _OrthoLearner
20+
from ..validate.sensitivity_analysis import (sensitivity_interval, RV, dml_sensitivity_values,
21+
sensitivity_summary)
2022

2123

2224
class _CausalForestFinalWrapper:
@@ -56,6 +58,11 @@ def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_v
5658
T_res = T_res.reshape((-1, 1))
5759
if Y_res.ndim == 1:
5860
Y_res = Y_res.reshape((-1, 1))
61+
62+
# if binary/continuous treatment and single outcome, can calculate sensitivity params
63+
if not ((self._d_t and self._d_t[0] > 1) or (self._d_y and self._d_y[0] > 1)):
64+
self.sensitivity_params = dml_sensitivity_values(T_res, Y_res)
65+
5966
self._model.fit(fts, T_res, Y_res, sample_weight=sample_weight)
6067
# Fit a doubly robust average effect
6168
if self._discrete_treatment and self._drate:
@@ -811,6 +818,113 @@ def tune(self, Y, T, *, X=None, W=None,
811818

812819
return self
813820

821+
def sensitivity_summary(self, null_hypothesis=0, alpha=0.05, c_y=0.05, c_t=0.05, rho=1., decimals=3):
    """
    Generate a summary of the sensitivity analysis for the ATE.

    Parameters
    ----------
    null_hypothesis: float, default 0
        The null_hypothesis value for the ATE.

    alpha: float, default 0.05
        The significance level for the sensitivity interval.

    c_y: float, default 0.05
        The level of confounding in the outcome. Ranges from 0 to 1.

    c_t: float, default 0.05
        The level of confounding in the treatment. Ranges from 0 to 1.

    rho: float, default 1.
        The correlation between the confounding in the outcome and the confounding
        in the treatment. Ranges from -1 to 1.

    decimals: int, default 3
        Number of decimal places to round each column to.

    """
    # Sensitivity params are only computed during fit for single-dimensional
    # treatment and outcome, so reject the multi-dimensional cases up front.
    if (self._d_t and self._d_t[0] > 1) or (self._d_y and self._d_y[0] > 1):
        raise ValueError(
            "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
    sensitivity_params = self._ortho_learner_model_final._model_final.sensitivity_params
    return sensitivity_summary(**sensitivity_params._asdict(), null_hypothesis=null_hypothesis, alpha=alpha,
                               c_y=c_y, c_t=c_t, rho=rho, decimals=decimals)
849+
850+
def sensitivity_interval(self, alpha=0.05, c_y=0.05, c_t=0.05, rho=1., interval_type='ci'):
    """
    Calculate the sensitivity interval for the ATE.

    The sensitivity interval is the range of values for the ATE that are
    consistent with the observed data, given a specified level of confounding.

    Can only be calculated when Y and T are single arrays, and T is binary or continuous.

    Based on `Chernozhukov et al. (2022) <https://www.nber.org/papers/w30302>`_

    Parameters
    ----------
    alpha: float, default 0.05
        The significance level for the sensitivity interval.

    c_y: float, default 0.05
        The level of confounding in the outcome. Ranges from 0 to 1.

    c_t: float, default 0.05
        The level of confounding in the treatment. Ranges from 0 to 1.

    rho: float, default 1.
        The correlation between the confounding in the outcome and the confounding
        in the treatment. Ranges from -1 to 1.

    interval_type: str, default 'ci'
        The type of interval to return. Can be 'ci' or 'theta'

    Returns
    -------
    (lb, ub): tuple of floats
        sensitivity interval for the ATE
    """
    # Sensitivity params are only computed during fit for single-dimensional
    # treatment and outcome, so reject the multi-dimensional cases up front.
    if (self._d_t and self._d_t[0] > 1) or (self._d_y and self._d_y[0] > 1):
        raise ValueError(
            "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
    sensitivity_params = self._ortho_learner_model_final._model_final.sensitivity_params
    return sensitivity_interval(**sensitivity_params._asdict(), alpha=alpha,
                                c_y=c_y, c_t=c_t, rho=rho, interval_type=interval_type)
886+
887+
def robustness_value(self, null_hypothesis=0, alpha=0.05, interval_type='ci'):
    """
    Calculate the robustness value for the ATE.

    The robustness value is the level of confounding (between 0 and 1) in
    *both* the treatment and outcome that would result in enough omitted variable bias such that
    we can no longer reject the null hypothesis. When null_hypothesis is the default of 0, the robustness value
    has the interpretation that it is the level of confounding that would make the
    ATE statistically insignificant.

    A higher value indicates a more robust estimate.

    Returns 0 if the original interval already includes the null_hypothesis.

    Can only be calculated when Y and T are single arrays, and T is binary or continuous.

    Based on `Chernozhukov et al. (2022) <https://www.nber.org/papers/w30302>`_

    Parameters
    ----------
    null_hypothesis: float, default 0
        The null_hypothesis value for the ATE.

    alpha: float, default 0.05
        The significance level for the robustness value.

    interval_type: str, default 'ci'
        The type of interval to return. Can be 'ci' or 'theta'

    Returns
    -------
    float
        The robustness value
    """
    multi_treatment = bool(self._d_t) and self._d_t[0] > 1
    multi_outcome = bool(self._d_y) and self._d_y[0] > 1
    if multi_treatment or multi_outcome:
        raise ValueError(
            "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
    params = self._ortho_learner_model_final._model_final.sensitivity_params
    return RV(**params._asdict(), null_hypothesis=null_hypothesis,
              alpha=alpha, interval_type=interval_type)
927+
814928
# override only so that we can update the docstring to indicate support for `blb`
815929
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None,
816930
cache_values=False, inference='auto'):

econml/dml/dml.py

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
shape, get_feature_names_or_default, filter_none_kwargs)
2727
from .._shap import _shap_explain_model_cate
2828
from ..sklearn_extensions.model_selection import get_selector, SingleModelSelector
29+
from ..validate.sensitivity_analysis import (sensitivity_interval, RV, dml_sensitivity_values,
30+
sensitivity_summary)
2931

3032

3133
def _combine(X, W, n_samples):
@@ -128,10 +130,12 @@ def _make_first_stage_selector(model, is_discrete, random_state):
128130

129131

130132
class _FinalWrapper:
131-
def __init__(self, model_final, fit_cate_intercept, featurizer, use_weight_trick):
133+
def __init__(self, model_final, fit_cate_intercept, featurizer,
134+
use_weight_trick, allow_sensitivity_analysis=False):
132135
self._model = clone(model_final, safe=False)
133136
self._use_weight_trick = use_weight_trick
134137
self._original_featurizer = clone(featurizer, safe=False)
138+
self.allow_sensitivity_analysis = allow_sensitivity_analysis
135139
if self._use_weight_trick:
136140
self._fit_cate_intercept = False
137141
self._featurizer = self._original_featurizer
@@ -170,6 +174,14 @@ def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_v
170174
self._d_t = shape(T_res)[1:]
171175
self._d_y = shape(Y_res)[1:]
172176
if not self._use_weight_trick:
177+
178+
# if binary/continuous treatment and single outcome, can calculate sensitivity params
179+
if self.allow_sensitivity_analysis and not (
180+
(self._d_t and self._d_t[0] > 1) or (
181+
self._d_y and self._d_y[0] > 1)
182+
):
183+
self.sensitivity_params = dml_sensitivity_values(T_res, Y_res)
184+
173185
fts = self._combine(X, T_res)
174186
filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight,
175187
freq_weight=freq_weight, sample_var=sample_var)
@@ -558,7 +570,8 @@ def _gen_model_final(self):
558570
return clone(self.model_final, safe=False)
559571

560572
def _gen_rlearner_model_final(self):
    # Build the final-stage wrapper; sensitivity analysis is enabled so that the
    # wrapper computes sensitivity params during fit when T and Y are single arrays.
    model_final = self._gen_model_final()
    featurizer = self._gen_featurizer()
    return _FinalWrapper(model_final, self.fit_cate_intercept, featurizer,
                         False, allow_sensitivity_analysis=True)
562575

563576
# override only so that we can update the docstring to indicate support for `LinearModelFinalInference`
564577
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
@@ -617,6 +630,114 @@ def bias_part_of_coef(self):
617630
def fit_cate_intercept_(self):
618631
return self.rlearner_model_final_._fit_cate_intercept
619632

633+
def sensitivity_summary(self, null_hypothesis=0, alpha=0.05, c_y=0.05, c_t=0.05, rho=1., decimals=3):
    """
    Generate a summary of the sensitivity analysis for the ATE.

    Parameters
    ----------
    null_hypothesis: float, default 0
        The null_hypothesis value for the ATE.

    alpha: float, default 0.05
        The significance level for the sensitivity interval.

    c_y: float, default 0.05
        The level of confounding in the outcome. Ranges from 0 to 1.

    c_t: float, default 0.05
        The level of confounding in the treatment. Ranges from 0 to 1.

    rho: float, default 1.
        The correlation between the confounding in the outcome and the confounding
        in the treatment. Ranges from -1 to 1.

    decimals: int, default 3
        Number of decimal places to round each column to.

    """
    # Sensitivity params are only computed during fit for single-dimensional
    # treatment and outcome, so reject the multi-dimensional cases up front.
    if (self._d_t and self._d_t[0] > 1) or (self._d_y and self._d_y[0] > 1):
        raise ValueError(
            "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
    sensitivity_params = self._ortho_learner_model_final._model_final.sensitivity_params
    return sensitivity_summary(**sensitivity_params._asdict(), null_hypothesis=null_hypothesis, alpha=alpha,
                               c_y=c_y, c_t=c_t, rho=rho, decimals=decimals)
661+
662+
663+
def sensitivity_interval(self, alpha=0.05, c_y=0.05, c_t=0.05, rho=1., interval_type='ci'):
    """
    Calculate the sensitivity interval for the ATE.

    The sensitivity interval is the range of values for the ATE that are
    consistent with the observed data, given a specified level of confounding.

    Can only be calculated when Y and T are single arrays, and T is binary or continuous.

    Based on `Chernozhukov et al. (2022) <https://www.nber.org/papers/w30302>`_

    Parameters
    ----------
    alpha: float, default 0.05
        The significance level for the sensitivity interval.

    c_y: float, default 0.05
        The level of confounding in the outcome. Ranges from 0 to 1.

    c_t: float, default 0.05
        The level of confounding in the treatment. Ranges from 0 to 1.

    rho: float, default 1.
        The correlation between the confounding in the outcome and the confounding
        in the treatment. Ranges from -1 to 1.

    interval_type: str, default 'ci'
        The type of interval to return. Can be 'ci' or 'theta'

    Returns
    -------
    (lb, ub): tuple of floats
        sensitivity interval for the ATE
    """
    # Sensitivity params are only computed during fit for single-dimensional
    # treatment and outcome, so reject the multi-dimensional cases up front.
    if (self._d_t and self._d_t[0] > 1) or (self._d_y and self._d_y[0] > 1):
        raise ValueError(
            "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
    sensitivity_params = self._ortho_learner_model_final._model_final.sensitivity_params
    return sensitivity_interval(**sensitivity_params._asdict(), alpha=alpha,
                                c_y=c_y, c_t=c_t, rho=rho, interval_type=interval_type)
699+
700+
def robustness_value(self, null_hypothesis=0, alpha=0.05, interval_type='ci'):
    """
    Calculate the robustness value for the ATE.

    The robustness value is the level of confounding (between 0 and 1) in
    *both* the treatment and outcome that would result in enough omitted variable bias such that
    we can no longer reject the null hypothesis. When null_hypothesis is the default of 0, the robustness value
    has the interpretation that it is the level of confounding that would make the
    ATE statistically insignificant.

    A higher value indicates a more robust estimate.

    Returns 0 if the original interval already includes the null_hypothesis.

    Can only be calculated when Y and T are single arrays, and T is binary or continuous.

    Based on `Chernozhukov et al. (2022) <https://www.nber.org/papers/w30302>`_

    Parameters
    ----------
    null_hypothesis: float, default 0
        The null_hypothesis value for the ATE.

    alpha: float, default 0.05
        The significance level for the robustness value.

    interval_type: str, default 'ci'
        The type of interval to return. Can be 'ci' or 'theta'

    Returns
    -------
    float
        The robustness value
    """
    multi_treatment = bool(self._d_t) and self._d_t[0] > 1
    multi_outcome = bool(self._d_y) and self._d_y[0] > 1
    if multi_treatment or multi_outcome:
        raise ValueError(
            "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
    params = self._ortho_learner_model_final._model_final.sensitivity_params
    return RV(**params._asdict(), null_hypothesis=null_hypothesis,
              alpha=alpha, interval_type=interval_type)
740+
620741

621742
class LinearDML(StatsModelsCateEstimatorMixin, DML):
622743
"""

0 commit comments

Comments
 (0)