@@ -35,6 +35,17 @@ def test_did_validation_post_treatment_formula():
35
35
model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
36
36
)
37
37
38
+ with pytest .raises (FormulaException ):
39
+ _ = cp .skl_experiments .DifferenceInDifferences (
40
+ df ,
41
+ formula = "y ~ 1 + group*post_SOMETHING" ,
42
+ time_variable_name = "t" ,
43
+ group_variable_name = "group" ,
44
+ treated = 1 ,
45
+ untreated = 0 ,
46
+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
47
+ )
48
+
38
49
39
50
def test_did_validation_post_treatment_data ():
40
51
"""Test that we get a DataException if do not include post_treatment in the data"""
@@ -57,6 +68,17 @@ def test_did_validation_post_treatment_data():
57
68
model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
58
69
)
59
70
71
+ with pytest .raises (DataException ):
72
+ _ = cp .skl_experiments .DifferenceInDifferences (
73
+ df ,
74
+ formula = "y ~ 1 + group*post_treatment" ,
75
+ time_variable_name = "t" ,
76
+ group_variable_name = "group" ,
77
+ treated = 1 ,
78
+ untreated = 0 ,
79
+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
80
+ )
81
+
60
82
61
83
def test_did_validation_unit_data ():
62
84
"""Test that we get a DataException if do not include unit in the data"""
@@ -79,6 +101,17 @@ def test_did_validation_unit_data():
79
101
model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
80
102
)
81
103
104
+ with pytest .raises (DataException ):
105
+ _ = cp .skl_experiments .DifferenceInDifferences (
106
+ df ,
107
+ formula = "y ~ 1 + group*post_treatment" ,
108
+ time_variable_name = "t" ,
109
+ group_variable_name = "group" ,
110
+ treated = 1 ,
111
+ untreated = 0 ,
112
+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
113
+ )
114
+
82
115
83
116
def test_did_validation_group_dummy_coded ():
84
117
"""Test that we get a DataException if the group variable is not dummy coded"""
@@ -101,6 +134,17 @@ def test_did_validation_group_dummy_coded():
101
134
model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
102
135
)
103
136
137
+ with pytest .raises (DataException ):
138
+ _ = cp .skl_experiments .DifferenceInDifferences (
139
+ df ,
140
+ formula = "y ~ 1 + group*post_treatment" ,
141
+ time_variable_name = "t" ,
142
+ group_variable_name = "group" ,
143
+ treated = 1 ,
144
+ untreated = 0 ,
145
+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
146
+ )
147
+
104
148
105
149
# Synthetic Control
106
150
@@ -118,6 +162,16 @@ def test_sc_input_error():
118
162
model = cp .pymc_models .WeightedSumFitter (sample_kwargs = sample_kwargs ),
119
163
)
120
164
165
+ with pytest .raises (BadIndexException ):
166
+ df = cp .load_data ("sc" )
167
+ treatment_time = pd .to_datetime ("2016 June 24" )
168
+ _ = cp .skl_experiments .SyntheticControl (
169
+ df ,
170
+ treatment_time ,
171
+ formula = "actual ~ 0 + a + b + c + d + e + f + g" ,
172
+ model = cp .skl_models .WeightedProportion (),
173
+ )
174
+
121
175
122
176
def test_sc_brexit_input_error ():
123
177
"""Confirm a BadIndexException is raised if the data index is datetime and the
@@ -187,6 +241,16 @@ def test_rd_validation_treated_in_formula():
187
241
treatment_threshold = 0.5 ,
188
242
)
189
243
244
+ with pytest .raises (FormulaException ):
245
+ from sklearn .linear_model import LinearRegression
246
+
247
+ _ = cp .skl_experiments .RegressionDiscontinuity (
248
+ df ,
249
+ formula = "y ~ 1 + x" ,
250
+ model = LinearRegression (),
251
+ treatment_threshold = 0.5 ,
252
+ )
253
+
190
254
191
255
def test_rd_validation_treated_is_dummy ():
192
256
"""Test that we get a DataException if treated is not dummy coded"""
@@ -206,6 +270,16 @@ def test_rd_validation_treated_is_dummy():
206
270
treatment_threshold = 0.5 ,
207
271
)
208
272
273
+ from sklearn .linear_model import LinearRegression
274
+
275
+ with pytest .raises (DataException ):
276
+ _ = cp .skl_experiments .RegressionDiscontinuity (
277
+ df ,
278
+ formula = "y ~ 1 + x + treated" ,
279
+ model = LinearRegression (),
280
+ treatment_threshold = 0.5 ,
281
+ )
282
+
209
283
210
284
def test_iv_treatment_var_is_present ():
211
285
"""Test the treatment variable is present for Instrumental Variable experiment"""
0 commit comments