|
@@ -26,6 +26,8 @@
 shape, get_feature_names_or_default, filter_none_kwargs)
 from .._shap import _shap_explain_model_cate
 from ..sklearn_extensions.model_selection import get_selector, SingleModelSelector
+from ..validate.sensitivity_analysis import (sensitivity_interval, RV, dml_sensitivity_values,
+                                             sensitivity_summary)


 def _combine(X, W, n_samples):
@@ -128,10 +130,12 @@ def _make_first_stage_selector(model, is_discrete, random_state):


 class _FinalWrapper:
-    def __init__(self, model_final, fit_cate_intercept, featurizer, use_weight_trick):
+    def __init__(self, model_final, fit_cate_intercept, featurizer,
+                 use_weight_trick, allow_sensitivity_analysis=False):
         self._model = clone(model_final, safe=False)
         self._use_weight_trick = use_weight_trick
         self._original_featurizer = clone(featurizer, safe=False)
+        self.allow_sensitivity_analysis = allow_sensitivity_analysis
         if self._use_weight_trick:
             self._fit_cate_intercept = False
             self._featurizer = self._original_featurizer
@@ -170,6 +174,14 @@ def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_v
         self._d_t = shape(T_res)[1:]
         self._d_y = shape(Y_res)[1:]
         if not self._use_weight_trick:
+
+            # if binary/continuous treatment and single outcome, can calculate sensitivity params
+            if self.allow_sensitivity_analysis and not (
+                (self._d_t and self._d_t[0] > 1) or (
+                    self._d_y and self._d_y[0] > 1)
+            ):
+                self.sensitivity_params = dml_sensitivity_values(T_res, Y_res)
+
             fts = self._combine(X, T_res)
             filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight,
                                                  freq_weight=freq_weight, sample_var=sample_var)
@@ -558,7 +570,8 @@ def _gen_model_final(self):
         return clone(self.model_final, safe=False)

     def _gen_rlearner_model_final(self):
-        return _FinalWrapper(self._gen_model_final(), self.fit_cate_intercept, self._gen_featurizer(), False)
+        return _FinalWrapper(self._gen_model_final(), self.fit_cate_intercept,
+                             self._gen_featurizer(), False, allow_sensitivity_analysis=True)

     # override only so that we can update the docstring to indicate support for `LinearModelFinalInference`
     def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
@@ -617,6 +630,119 @@ def bias_part_of_coef(self):
     def fit_cate_intercept_(self):
         return self.rlearner_model_final_._fit_cate_intercept

+    def sensitivity_summary(self, null_hypothesis=0, alpha=0.05, c_y=0.05, c_t=0.05, rho=1., decimals=3):
+        """
+        Generate a summary of the sensitivity analysis for the ATE.
+
+        Parameters
+        ----------
+        null_hypothesis: float, default 0
+            The null_hypothesis value for the ATE.
+
+        alpha: float, default 0.05
+            The significance level for the sensitivity interval.
+
+        c_y: float, default 0.05
+            The level of confounding in the outcome. Ranges from 0 to 1.
+
+        c_t: float, default 0.05
+            The level of confounding in the treatment. Ranges from 0 to 1.
+
+        rho: float, default 1.
+            The correlation between the confounding in the outcome and the confounding in the treatment.
+
+        decimals: int, default 3
+            Number of decimal places to round each column to.
+
+        """
+        if (self._d_t and self._d_t[0] > 1) or (self._d_y and self._d_y[0] > 1):
+            raise ValueError(
+                "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
+        sensitivity_params = self._ortho_learner_model_final._model_final.sensitivity_params
+        return sensitivity_summary(**sensitivity_params._asdict(), null_hypothesis=null_hypothesis, alpha=alpha,
+                                   c_y=c_y, c_t=c_t, rho=rho, decimals=decimals)
+
+    def sensitivity_interval(self, alpha=0.05, c_y=0.05, c_t=0.05, rho=1., interval_type='ci'):
+        """
+        Calculate the sensitivity interval for the ATE.
+
+        The sensitivity interval is the range of values for the ATE that are
+        consistent with the observed data, given a specified level of confounding.
+
+        Can only be calculated when Y and T are single arrays, and T is binary or continuous.
+
+        Based on `Chernozhukov et al. (2022) <https://www.nber.org/papers/w30302>`_
+
+        Parameters
+        ----------
+        alpha: float, default 0.05
+            The significance level for the sensitivity interval.
+
+        c_y: float, default 0.05
+            The level of confounding in the outcome. Ranges from 0 to 1.
+
+        c_t: float, default 0.05
+            The level of confounding in the treatment. Ranges from 0 to 1.
+
+        rho: float, default 1.
+            The correlation between the confounding in the outcome and the confounding in the treatment.
+
+        interval_type: str, default 'ci'
+            The type of interval to return. Can be 'ci' or 'theta'.
+
+        Returns
+        -------
+        (lb, ub): tuple of floats
+            sensitivity interval for the ATE
+        """
+        if (self._d_t and self._d_t[0] > 1) or (self._d_y and self._d_y[0] > 1):
+            raise ValueError(
+                "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
+        sensitivity_params = self._ortho_learner_model_final._model_final.sensitivity_params
+        return sensitivity_interval(**sensitivity_params._asdict(), alpha=alpha,
+                                    c_y=c_y, c_t=c_t, rho=rho, interval_type=interval_type)
+
+    def robustness_value(self, null_hypothesis=0, alpha=0.05, interval_type='ci'):
+        """
+        Calculate the robustness value for the ATE.
+
+        The robustness value is the level of confounding (between 0 and 1) in
+        *both* the treatment and outcome that would produce enough omitted variable
+        bias that we could no longer reject the null hypothesis. With the default
+        null_hypothesis of 0, it is the level of confounding that would make the
+        ATE statistically insignificant.
+
+        A higher value indicates a more robust estimate.
+
+        Returns 0 if the original interval already includes the null_hypothesis.
+
+        Can only be calculated when Y and T are single arrays, and T is binary or continuous.
+
+        Based on `Chernozhukov et al. (2022) <https://www.nber.org/papers/w30302>`_
+
+        Parameters
+        ----------
+        null_hypothesis: float, default 0
+            The null_hypothesis value for the ATE.
+
+        alpha: float, default 0.05
+            The significance level for the robustness value.
+
+        interval_type: str, default 'ci'
+            The type of interval to return. Can be 'ci' or 'theta'.
+
+        Returns
+        -------
+        float
+            The robustness value
+        """
+        if (self._d_t and self._d_t[0] > 1) or (self._d_y and self._d_y[0] > 1):
+            raise ValueError(
+                "Sensitivity analysis for DML is not supported for multi-dimensional outcomes or treatments.")
+        sensitivity_params = self._ortho_learner_model_final._model_final.sensitivity_params
+        return RV(**sensitivity_params._asdict(), null_hypothesis=null_hypothesis,
+                  alpha=alpha, interval_type=interval_type)
+

 class LinearDML(StatsModelsCateEstimatorMixin, DML):
     """
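
A minimal usage sketch of the API added above, for reviewers. The toy data and variable names are illustrative only; the sketch assumes these methods land on the class shown above (which `LinearDML` subclasses) and that fitting succeeds with the default first-stage models:

```python
import numpy as np
from econml.dml import LinearDML

# toy data: single continuous treatment, single outcome
# (the only case the new sensitivity analysis supports)
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
T = X[:, 0] + rng.normal(size=1000)
Y = 2.0 * T + X[:, 1] + rng.normal(size=1000)

est = LinearDML()
est.fit(Y, T, X=X)

# summary of the sensitivity analysis at confounding levels c_y, c_t
print(est.sensitivity_summary(c_y=0.05, c_t=0.05))

# range of ATE values consistent with the data at those confounding levels
lb, ub = est.sensitivity_interval(alpha=0.05, c_y=0.05, c_t=0.05)

# level of confounding in *both* T and Y needed to make the ATE insignificant
rv = est.robustness_value(alpha=0.05)
```

Note that `robustness_value` takes no `c_y`/`c_t` arguments: it solves for the common confounding level rather than evaluating a supplied one.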
|
|