Skip to content

Commit 6574481

Browse files
committed
add data validation checks for scikit-learn based experiments
1 parent cdbbd03 commit 6574481

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

causalpy/tests/test_input_validation.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ def test_did_validation_post_treatment_formula():
3535
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
3636
)
3737

38+
with pytest.raises(FormulaException):
39+
_ = cp.skl_experiments.DifferenceInDifferences(
40+
df,
41+
formula="y ~ 1 + group*post_SOMETHING",
42+
time_variable_name="t",
43+
group_variable_name="group",
44+
treated=1,
45+
untreated=0,
46+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
47+
)
48+
3849

3950
def test_did_validation_post_treatment_data():
4051
"""Test that we get a DataException if we do not include post_treatment in the data"""
@@ -57,6 +68,17 @@ def test_did_validation_post_treatment_data():
5768
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
5869
)
5970

71+
with pytest.raises(DataException):
72+
_ = cp.skl_experiments.DifferenceInDifferences(
73+
df,
74+
formula="y ~ 1 + group*post_treatment",
75+
time_variable_name="t",
76+
group_variable_name="group",
77+
treated=1,
78+
untreated=0,
79+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
80+
)
81+
6082

6183
def test_did_validation_unit_data():
6284
"""Test that we get a DataException if we do not include unit in the data"""
@@ -79,6 +101,17 @@ def test_did_validation_unit_data():
79101
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
80102
)
81103

104+
with pytest.raises(DataException):
105+
_ = cp.skl_experiments.DifferenceInDifferences(
106+
df,
107+
formula="y ~ 1 + group*post_treatment",
108+
time_variable_name="t",
109+
group_variable_name="group",
110+
treated=1,
111+
untreated=0,
112+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
113+
)
114+
82115

83116
def test_did_validation_group_dummy_coded():
84117
"""Test that we get a DataException if the group variable is not dummy coded"""
@@ -101,6 +134,17 @@ def test_did_validation_group_dummy_coded():
101134
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
102135
)
103136

137+
with pytest.raises(DataException):
138+
_ = cp.skl_experiments.DifferenceInDifferences(
139+
df,
140+
formula="y ~ 1 + group*post_treatment",
141+
time_variable_name="t",
142+
group_variable_name="group",
143+
treated=1,
144+
untreated=0,
145+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
146+
)
147+
104148

105149
# Synthetic Control
106150

@@ -118,6 +162,16 @@ def test_sc_input_error():
118162
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
119163
)
120164

165+
with pytest.raises(BadIndexException):
166+
df = cp.load_data("sc")
167+
treatment_time = pd.to_datetime("2016 June 24")
168+
_ = cp.skl_experiments.SyntheticControl(
169+
df,
170+
treatment_time,
171+
formula="actual ~ 0 + a + b + c + d + e + f + g",
172+
model=cp.skl_models.WeightedProportion(),
173+
)
174+
121175

122176
def test_sc_brexit_input_error():
123177
"""Confirm a BadIndexException is raised if the data index is datetime and the
@@ -187,6 +241,16 @@ def test_rd_validation_treated_in_formula():
187241
treatment_threshold=0.5,
188242
)
189243

244+
with pytest.raises(FormulaException):
245+
from sklearn.linear_model import LinearRegression
246+
247+
_ = cp.skl_experiments.RegressionDiscontinuity(
248+
df,
249+
formula="y ~ 1 + x",
250+
model=LinearRegression(),
251+
treatment_threshold=0.5,
252+
)
253+
190254

191255
def test_rd_validation_treated_is_dummy():
192256
"""Test that we get a DataException if treated is not dummy coded"""
@@ -206,6 +270,16 @@ def test_rd_validation_treated_is_dummy():
206270
treatment_threshold=0.5,
207271
)
208272

273+
from sklearn.linear_model import LinearRegression
274+
275+
with pytest.raises(DataException):
276+
_ = cp.skl_experiments.RegressionDiscontinuity(
277+
df,
278+
formula="y ~ 1 + x + treated",
279+
model=LinearRegression(),
280+
treatment_threshold=0.5,
281+
)
282+
209283

210284
def test_iv_treatment_var_is_present():
211285
"""Test the treatment variable is present for Instrumental Variable experiment"""

0 commit comments

Comments
 (0)