Skip to content

Commit a63a7f9

Browse files
authored
Merge pull request #290 from pymc-labs/data-validation
Expand data validation to also cover the scikit-learn experiment classes
2 parents 67d9e0e + 6574481 commit a63a7f9

File tree

7 files changed

+237
-121
lines changed

7 files changed

+237
-121
lines changed

causalpy/data_validation.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import warnings # noqa: I001
2+
3+
import pandas as pd
4+
import numpy as np
5+
from causalpy.custom_exceptions import (
6+
BadIndexException, # NOQA
7+
DataException,
8+
FormulaException,
9+
)
10+
from causalpy.utils import _is_variable_dummy_coded
11+
12+
13+
class PrePostFitDataValidator:
14+
"""Mixin class for validating the input data and model formula for PrePostFit"""
15+
16+
def _input_validation(self, data, treatment_time):
17+
"""Validate the input data and model formula for correctness"""
18+
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
19+
treatment_time, pd.Timestamp
20+
):
21+
raise BadIndexException(
22+
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
23+
)
24+
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
25+
treatment_time, pd.Timestamp
26+
):
27+
raise BadIndexException(
28+
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
29+
)
30+
31+
32+
class DiDDataValidator:
33+
"""Mixin class for validating the input data and model formula for Difference in Differences experiments."""
34+
35+
def _input_validation(self):
36+
"""Validate the input data and model formula for correctness"""
37+
if "post_treatment" not in self.formula:
38+
raise FormulaException(
39+
"A predictor called `post_treatment` should be in the formula"
40+
)
41+
42+
if "post_treatment" not in self.data.columns:
43+
raise DataException(
44+
"Require a boolean column labelling observations which are `treated`"
45+
)
46+
47+
if "unit" not in self.data.columns:
48+
raise DataException(
49+
"Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501
50+
)
51+
52+
if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False:
53+
raise DataException(
54+
f"""The grouping variable {self.group_variable_name} should be dummy
55+
coded. Consisting of 0's and 1's only."""
56+
)
57+
58+
59+
class RDDataValidator:
60+
"""Mixin class for validating the input data and model formula for Regression Discontinuity experiments."""
61+
62+
def _input_validation(self):
63+
"""Validate the input data and model formula for correctness"""
64+
if "treated" not in self.formula:
65+
raise FormulaException(
66+
"A predictor called `treated` should be in the formula"
67+
)
68+
69+
if _is_variable_dummy_coded(self.data["treated"]) is False:
70+
raise DataException(
71+
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
72+
)
73+
74+
75+
class RegressionKinkDataValidator:
76+
"""Mixin class for validating the input data and model formula for Regression Kink experiments."""
77+
78+
def _input_validation(self):
79+
"""Validate the input data and model formula for correctness"""
80+
if "treated" not in self.formula:
81+
raise FormulaException(
82+
"A predictor called `treated` should be in the formula"
83+
)
84+
85+
if _is_variable_dummy_coded(self.data["treated"]) is False:
86+
raise DataException(
87+
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
88+
)
89+
90+
if self.bandwidth <= 0:
91+
raise ValueError("The bandwidth must be greater than zero.")
92+
93+
if self.epsilon <= 0:
94+
raise ValueError("Epsilon must be greater than zero.")
95+
96+
97+
class PrePostNEGDDataValidator:
98+
"""Mixin class for validating the input data and model formula for PrePostNEGD experiments."""
99+
100+
def _input_validation(self) -> None:
101+
"""Validate the input data and model formula for correctness"""
102+
if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
103+
raise DataException(
104+
f"""
105+
There must be 2 levels of the grouping variable
106+
{self.group_variable_name}. I.e. the treated and untreated.
107+
"""
108+
)
109+
110+
111+
class IVDataValidator:
112+
"""Mixin class for validating the input data and model formula for IV experiments."""
113+
114+
def _input_validation(self):
115+
"""Validate the input data and model formula for correctness"""
116+
treatment = self.instruments_formula.split("~")[0]
117+
test = treatment.strip() in self.instruments_data.columns
118+
test = test & (treatment.strip() in self.data.columns)
119+
if not test:
120+
raise DataException(
121+
f"""
122+
The treatment variable:
123+
{treatment} must appear in the instrument_data to be used
124+
as an outcome variable and in the data object to be used as a covariate.
125+
"""
126+
)
127+
Z = self.data[treatment.strip()]
128+
check_binary = len(np.unique(Z)) > 2
129+
if check_binary:
130+
warnings.warn(
131+
"""Warning. The treatment variable is not Binary.
132+
This is not necessarily a problem but it violates
133+
the assumption of a simple IV experiment.
134+
The coefficients should be interpreted appropriately."""
135+
)

causalpy/pymc_experiments.py

Lines changed: 14 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@
2323
from patsy import build_design_matrices, dmatrices
2424
from sklearn.linear_model import LinearRegression as sk_lin_reg
2525

26-
from causalpy.custom_exceptions import (
27-
BadIndexException, # NOQA
28-
DataException,
29-
FormulaException,
26+
from causalpy.data_validation import (
27+
PrePostFitDataValidator,
28+
DiDDataValidator,
29+
RDDataValidator,
30+
RegressionKinkDataValidator,
31+
PrePostNEGDDataValidator,
32+
IVDataValidator,
3033
)
3134
from causalpy.plot_utils import plot_xY
32-
from causalpy.utils import _is_variable_dummy_coded, round_num
35+
from causalpy.utils import round_num
3336

3437
LEGEND_FONT_SIZE = 12
3538
az.style.use("arviz-darkgrid")
@@ -108,7 +111,7 @@ def print_coefficients(self, round_to=None) -> None:
108111
)
109112

110113

111-
class PrePostFit(ExperimentalDesign):
114+
class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
112115
"""
113116
A class to analyse quasi-experiments where parameter estimation is based on just
114117
the pre-intervention data.
@@ -160,7 +163,6 @@ def __init__(
160163
) -> None:
161164
super().__init__(model=model, **kwargs)
162165
self._input_validation(data, treatment_time)
163-
164166
self.treatment_time = treatment_time
165167
# set experiment type - usually done in subclasses
166168
self.expt_type = "Pre-Post Fit"
@@ -214,21 +216,6 @@ def __init__(
214216
# cumulative impact post
215217
self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind")
216218

217-
def _input_validation(self, data, treatment_time):
218-
"""Validate the input data and model formula for correctness"""
219-
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
220-
treatment_time, pd.Timestamp
221-
):
222-
raise BadIndexException(
223-
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
224-
)
225-
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
226-
treatment_time, pd.Timestamp
227-
):
228-
raise BadIndexException(
229-
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
230-
)
231-
232219
def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
233220
"""
234221
Plot the results
@@ -438,7 +425,7 @@ def plot(self, plot_predictors=False, **kwargs):
438425
return fig, ax
439426

440427

441-
class DifferenceInDifferences(ExperimentalDesign):
428+
class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator):
442429
"""A class to analyse data from Difference in Difference settings.
443430
444431
.. note::
@@ -568,29 +555,6 @@ def __init__(
568555
if "post_treatment" in label and self.group_variable_name in label:
569556
self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})
570557

571-
def _input_validation(self):
572-
"""Validate the input data and model formula for correctness"""
573-
if "post_treatment" not in self.formula:
574-
raise FormulaException(
575-
"A predictor called `post_treatment` should be in the formula"
576-
)
577-
578-
if "post_treatment" not in self.data.columns:
579-
raise DataException(
580-
"Require a boolean column labelling observations which are `treated`"
581-
)
582-
583-
if "unit" not in self.data.columns:
584-
raise DataException(
585-
"Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501
586-
)
587-
588-
if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False:
589-
raise DataException(
590-
f"""The grouping variable {self.group_variable_name} should be dummy
591-
coded. Consisting of 0's and 1's only."""
592-
)
593-
594558
def plot(self, round_to=None):
595559
"""Plot the results.
596560
@@ -749,7 +713,7 @@ def summary(self, round_to=None) -> None:
749713
self.print_coefficients(round_to)
750714

751715

752-
class RegressionDiscontinuity(ExperimentalDesign):
716+
class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
753717
"""
754718
A class to analyse sharp regression discontinuity experiments.
755719
@@ -876,18 +840,6 @@ def __init__(
876840
- self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
877841
)
878842

879-
def _input_validation(self):
880-
"""Validate the input data and model formula for correctness"""
881-
if "treated" not in self.formula:
882-
raise FormulaException(
883-
"A predictor called `treated` should be in the formula"
884-
)
885-
886-
if _is_variable_dummy_coded(self.data["treated"]) is False:
887-
raise DataException(
888-
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
889-
)
890-
891843
def _is_treated(self, x):
892844
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
893845
@@ -970,7 +922,7 @@ def summary(self, round_to: None) -> None:
970922
self.print_coefficients(round_to)
971923

972924

973-
class RegressionKink(ExperimentalDesign):
925+
class RegressionKink(ExperimentalDesign, RegressionKinkDataValidator):
974926
"""
975927
A class to analyse sharp regression kink experiments.
976928
@@ -1095,24 +1047,6 @@ def _probe_kink_point(self):
10951047
mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"]
10961048
return mu_kink_left, mu_kink, mu_kink_right
10971049

1098-
def _input_validation(self):
1099-
"""Validate the input data and model formula for correctness"""
1100-
if "treated" not in self.formula:
1101-
raise FormulaException(
1102-
"A predictor called `treated` should be in the formula"
1103-
)
1104-
1105-
if _is_variable_dummy_coded(self.data["treated"]) is False:
1106-
raise DataException(
1107-
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
1108-
)
1109-
1110-
if self.bandwidth <= 0:
1111-
raise ValueError("The bandwidth must be greater than zero.")
1112-
1113-
if self.epsilon <= 0:
1114-
raise ValueError("Epsilon must be greater than zero.")
1115-
11161050
def _is_treated(self, x):
11171051
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
11181052
return np.greater_equal(x, self.kink_point)
@@ -1193,7 +1127,7 @@ def summary(self, round_to=None) -> None:
11931127
self.print_coefficients(round_to)
11941128

11951129

1196-
class PrePostNEGD(ExperimentalDesign):
1130+
class PrePostNEGD(ExperimentalDesign, PrePostNEGDDataValidator):
11971131
"""
11981132
A class to analyse data from pretest/posttest designs
11991133
@@ -1302,18 +1236,6 @@ def __init__(
13021236
{"coeffs": self._get_treatment_effect_coeff()}
13031237
)
13041238

1305-
# ================================================================
1306-
1307-
def _input_validation(self) -> None:
1308-
"""Validate the input data and model formula for correctness"""
1309-
if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
1310-
raise DataException(
1311-
f"""
1312-
There must be 2 levels of the grouping variable
1313-
{self.group_variable_name}. I.e. the treated and untreated.
1314-
"""
1315-
)
1316-
13171239
def plot(self, round_to=None):
13181240
"""Plot the results
13191241
@@ -1408,7 +1330,7 @@ def _get_treatment_effect_coeff(self) -> str:
14081330
raise NameError("Unable to find coefficient name for the treatment effect")
14091331

14101332

1411-
class InstrumentalVariable(ExperimentalDesign):
1333+
class InstrumentalVariable(ExperimentalDesign, IVDataValidator):
14121334
"""
14131335
A class to analyse instrumental variable style experiments.
14141336
@@ -1555,26 +1477,3 @@ def get_naive_OLS_fit(self):
15551477
beta_params.insert(0, ols_reg.intercept_[0])
15561478
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
15571479
self.ols_reg = ols_reg
1558-
1559-
def _input_validation(self):
1560-
"""Validate the input data and model formula for correctness"""
1561-
treatment = self.instruments_formula.split("~")[0]
1562-
test = treatment.strip() in self.instruments_data.columns
1563-
test = test & (treatment.strip() in self.data.columns)
1564-
if not test:
1565-
raise DataException(
1566-
f"""
1567-
The treatment variable:
1568-
{treatment} must appear in the instrument_data to be used
1569-
as an outcome variable and in the data object to be used as a covariate.
1570-
"""
1571-
)
1572-
Z = self.data[treatment.strip()]
1573-
check_binary = len(np.unique(Z)) > 2
1574-
if check_binary:
1575-
warnings.warn(
1576-
"""Warning. The treatment variable is not Binary.
1577-
This is not necessarily a problem but it violates
1578-
the assumption of a simple IV experiment.
1579-
The coefficients should be interpreted appropriately."""
1580-
)

0 commit comments

Comments
 (0)