|
23 | 23 | from patsy import build_design_matrices, dmatrices
|
24 | 24 | from sklearn.linear_model import LinearRegression as sk_lin_reg
|
25 | 25 |
|
26 |
| -from causalpy.custom_exceptions import ( |
27 |
| - BadIndexException, # NOQA |
28 |
| - DataException, |
29 |
| - FormulaException, |
| 26 | +from causalpy.data_validation import ( |
| 27 | + PrePostFitDataValidator, |
| 28 | + DiDDataValidator, |
| 29 | + RDDataValidator, |
| 30 | + RegressionKinkDataValidator, |
| 31 | + PrePostNEGDDataValidator, |
| 32 | + IVDataValidator, |
30 | 33 | )
|
31 | 34 | from causalpy.plot_utils import plot_xY
|
32 |
| -from causalpy.utils import _is_variable_dummy_coded, round_num |
| 35 | +from causalpy.utils import round_num |
33 | 36 |
|
34 | 37 | LEGEND_FONT_SIZE = 12
|
35 | 38 | az.style.use("arviz-darkgrid")
|
@@ -108,7 +111,7 @@ def print_coefficients(self, round_to=None) -> None:
|
108 | 111 | )
|
109 | 112 |
|
110 | 113 |
|
111 |
| -class PrePostFit(ExperimentalDesign): |
| 114 | +class PrePostFit(ExperimentalDesign, PrePostFitDataValidator): |
112 | 115 | """
|
113 | 116 | A class to analyse quasi-experiments where parameter estimation is based on just
|
114 | 117 | the pre-intervention data.
|
@@ -160,7 +163,6 @@ def __init__(
|
160 | 163 | ) -> None:
|
161 | 164 | super().__init__(model=model, **kwargs)
|
162 | 165 | self._input_validation(data, treatment_time)
|
163 |
| - |
164 | 166 | self.treatment_time = treatment_time
|
165 | 167 | # set experiment type - usually done in subclasses
|
166 | 168 | self.expt_type = "Pre-Post Fit"
|
@@ -214,21 +216,6 @@ def __init__(
|
214 | 216 | # cumulative impact post
|
215 | 217 | self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind")
|
216 | 218 |
|
217 |
| - def _input_validation(self, data, treatment_time): |
218 |
| - """Validate the input data and model formula for correctness""" |
219 |
| - if isinstance(data.index, pd.DatetimeIndex) and not isinstance( |
220 |
| - treatment_time, pd.Timestamp |
221 |
| - ): |
222 |
| - raise BadIndexException( |
223 |
| - "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp." |
224 |
| - ) |
225 |
| - if not isinstance(data.index, pd.DatetimeIndex) and isinstance( |
226 |
| - treatment_time, pd.Timestamp |
227 |
| - ): |
228 |
| - raise BadIndexException( |
229 |
| - "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501 |
230 |
| - ) |
231 |
| - |
232 | 219 | def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
|
233 | 220 | """
|
234 | 221 | Plot the results
|
@@ -438,7 +425,7 @@ def plot(self, plot_predictors=False, **kwargs):
|
438 | 425 | return fig, ax
|
439 | 426 |
|
440 | 427 |
|
441 |
| -class DifferenceInDifferences(ExperimentalDesign): |
| 428 | +class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator): |
442 | 429 | """A class to analyse data from Difference in Difference settings.
|
443 | 430 |
|
444 | 431 | .. note::
|
@@ -568,29 +555,6 @@ def __init__(
|
568 | 555 | if "post_treatment" in label and self.group_variable_name in label:
|
569 | 556 | self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})
|
570 | 557 |
|
571 |
| - def _input_validation(self): |
572 |
| - """Validate the input data and model formula for correctness""" |
573 |
| - if "post_treatment" not in self.formula: |
574 |
| - raise FormulaException( |
575 |
| - "A predictor called `post_treatment` should be in the formula" |
576 |
| - ) |
577 |
| - |
578 |
| - if "post_treatment" not in self.data.columns: |
579 |
| - raise DataException( |
580 |
| - "Require a boolean column labelling observations which are `treated`" |
581 |
| - ) |
582 |
| - |
583 |
| - if "unit" not in self.data.columns: |
584 |
| - raise DataException( |
585 |
| - "Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501 |
586 |
| - ) |
587 |
| - |
588 |
| - if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False: |
589 |
| - raise DataException( |
590 |
| - f"""The grouping variable {self.group_variable_name} should be dummy |
591 |
| - coded. Consisting of 0's and 1's only.""" |
592 |
| - ) |
593 |
| - |
594 | 558 | def plot(self, round_to=None):
|
595 | 559 | """Plot the results.
|
596 | 560 |
|
@@ -749,7 +713,7 @@ def summary(self, round_to=None) -> None:
|
749 | 713 | self.print_coefficients(round_to)
|
750 | 714 |
|
751 | 715 |
|
752 |
| -class RegressionDiscontinuity(ExperimentalDesign): |
| 716 | +class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator): |
753 | 717 | """
|
754 | 718 | A class to analyse sharp regression discontinuity experiments.
|
755 | 719 |
|
@@ -876,18 +840,6 @@ def __init__(
|
876 | 840 | - self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
|
877 | 841 | )
|
878 | 842 |
|
879 |
| - def _input_validation(self): |
880 |
| - """Validate the input data and model formula for correctness""" |
881 |
| - if "treated" not in self.formula: |
882 |
| - raise FormulaException( |
883 |
| - "A predictor called `treated` should be in the formula" |
884 |
| - ) |
885 |
| - |
886 |
| - if _is_variable_dummy_coded(self.data["treated"]) is False: |
887 |
| - raise DataException( |
888 |
| - """The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501 |
889 |
| - ) |
890 |
| - |
891 | 843 | def _is_treated(self, x):
|
892 | 844 | """Returns ``True`` if `x` is greater than or equal to the treatment threshold.
|
893 | 845 |
|
@@ -970,7 +922,7 @@ def summary(self, round_to: None) -> None:
|
970 | 922 | self.print_coefficients(round_to)
|
971 | 923 |
|
972 | 924 |
|
973 |
| -class RegressionKink(ExperimentalDesign): |
| 925 | +class RegressionKink(ExperimentalDesign, RegressionKinkDataValidator): |
974 | 926 | """
|
975 | 927 | A class to analyse sharp regression kink experiments.
|
976 | 928 |
|
@@ -1095,24 +1047,6 @@ def _probe_kink_point(self):
|
1095 | 1047 | mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"]
|
1096 | 1048 | return mu_kink_left, mu_kink, mu_kink_right
|
1097 | 1049 |
|
1098 |
| - def _input_validation(self): |
1099 |
| - """Validate the input data and model formula for correctness""" |
1100 |
| - if "treated" not in self.formula: |
1101 |
| - raise FormulaException( |
1102 |
| - "A predictor called `treated` should be in the formula" |
1103 |
| - ) |
1104 |
| - |
1105 |
| - if _is_variable_dummy_coded(self.data["treated"]) is False: |
1106 |
| - raise DataException( |
1107 |
| - """The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501 |
1108 |
| - ) |
1109 |
| - |
1110 |
| - if self.bandwidth <= 0: |
1111 |
| - raise ValueError("The bandwidth must be greater than zero.") |
1112 |
| - |
1113 |
| - if self.epsilon <= 0: |
1114 |
| - raise ValueError("Epsilon must be greater than zero.") |
1115 |
| - |
1116 | 1050 | def _is_treated(self, x):
|
1117 | 1051 | """Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
|
1118 | 1052 | return np.greater_equal(x, self.kink_point)
|
@@ -1193,7 +1127,7 @@ def summary(self, round_to=None) -> None:
|
1193 | 1127 | self.print_coefficients(round_to)
|
1194 | 1128 |
|
1195 | 1129 |
|
1196 |
| -class PrePostNEGD(ExperimentalDesign): |
| 1130 | +class PrePostNEGD(ExperimentalDesign, PrePostNEGDDataValidator): |
1197 | 1131 | """
|
1198 | 1132 | A class to analyse data from pretest/posttest designs
|
1199 | 1133 |
|
@@ -1302,18 +1236,6 @@ def __init__(
|
1302 | 1236 | {"coeffs": self._get_treatment_effect_coeff()}
|
1303 | 1237 | )
|
1304 | 1238 |
|
1305 |
| - # ================================================================ |
1306 |
| - |
1307 |
| - def _input_validation(self) -> None: |
1308 |
| - """Validate the input data and model formula for correctness""" |
1309 |
| - if not _is_variable_dummy_coded(self.data[self.group_variable_name]): |
1310 |
| - raise DataException( |
1311 |
| - f""" |
1312 |
| - There must be 2 levels of the grouping variable |
1313 |
| - {self.group_variable_name}. I.e. the treated and untreated. |
1314 |
| - """ |
1315 |
| - ) |
1316 |
| - |
1317 | 1239 | def plot(self, round_to=None):
|
1318 | 1240 | """Plot the results
|
1319 | 1241 |
|
@@ -1408,7 +1330,7 @@ def _get_treatment_effect_coeff(self) -> str:
|
1408 | 1330 | raise NameError("Unable to find coefficient name for the treatment effect")
|
1409 | 1331 |
|
1410 | 1332 |
|
1411 |
| -class InstrumentalVariable(ExperimentalDesign): |
| 1333 | +class InstrumentalVariable(ExperimentalDesign, IVDataValidator): |
1412 | 1334 | """
|
1413 | 1335 | A class to analyse instrumental variable style experiments.
|
1414 | 1336 |
|
@@ -1555,26 +1477,3 @@ def get_naive_OLS_fit(self):
|
1555 | 1477 | beta_params.insert(0, ols_reg.intercept_[0])
|
1556 | 1478 | self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
|
1557 | 1479 | self.ols_reg = ols_reg
|
1558 |
| - |
1559 |
| - def _input_validation(self): |
1560 |
| - """Validate the input data and model formula for correctness""" |
1561 |
| - treatment = self.instruments_formula.split("~")[0] |
1562 |
| - test = treatment.strip() in self.instruments_data.columns |
1563 |
| - test = test & (treatment.strip() in self.data.columns) |
1564 |
| - if not test: |
1565 |
| - raise DataException( |
1566 |
| - f""" |
1567 |
| - The treatment variable: |
1568 |
| - {treatment} must appear in the instrument_data to be used |
1569 |
| - as an outcome variable and in the data object to be used as a covariate. |
1570 |
| - """ |
1571 |
| - ) |
1572 |
| - Z = self.data[treatment.strip()] |
1573 |
| - check_binary = len(np.unique(Z)) > 2 |
1574 |
| - if check_binary: |
1575 |
| - warnings.warn( |
1576 |
| - """Warning. The treatment variable is not Binary. |
1577 |
| - This is not necessarily a problem but it violates |
1578 |
| - the assumption of a simple IV experiment. |
1579 |
| - The coefficients should be interpreted appropriately.""" |
1580 |
| - ) |
0 commit comments