Skip to content

Commit 2ad1ba3

Browse files
committed
Test that exceptions are raised for the new input validation
1 parent dab61dc commit 2ad1ba3

File tree

4 files changed

+120
-44
lines changed

4 files changed

+120
-44
lines changed

causalpy/pymc_experiments.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"""
1313

1414
import warnings
15-
from typing import Optional, Union
15+
from typing import Union
1616

1717
import arviz as az
1818
import matplotlib.pyplot as plt
@@ -794,7 +794,7 @@ def __init__(
794794
model=None,
795795
running_variable_name: str = "x",
796796
epsilon: float = 0.001,
797-
bandwidth: Optional[float] = None,
797+
bandwidth: float = np.inf,
798798
**kwargs,
799799
):
800800
super().__init__(model=model, **kwargs)
@@ -807,7 +807,7 @@ def __init__(
807807
self.bandwidth = bandwidth
808808
self._input_validation()
809809

810-
if self.bandwidth is not None:
810+
if self.bandwidth is not np.inf:
811811
fmin = self.treatment_threshold - self.bandwidth
812812
fmax = self.treatment_threshold + self.bandwidth
813813
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
@@ -836,7 +836,7 @@ def __init__(
836836
self.score = self.model.score(X=self.X, y=self.y)
837837

838838
# get the model predictions of the observed data
839-
if self.bandwidth is not None:
839+
if self.bandwidth is not np.inf:
840840
xi = np.linspace(fmin, fmax, 200)
841841
else:
842842
xi = np.linspace(
@@ -988,7 +988,7 @@ def __init__(
988988
model=None,
989989
running_variable_name: str = "x",
990990
epsilon: float = 0.001,
991-
bandwidth: Optional[float] = None,
991+
bandwidth: float = np.inf,
992992
**kwargs,
993993
):
994994
super().__init__(model=model, **kwargs)
@@ -1001,7 +1001,7 @@ def __init__(
10011001
self.bandwidth = bandwidth
10021002
self._input_validation()
10031003

1004-
if self.bandwidth is not None:
1004+
if self.bandwidth is not np.inf:
10051005
fmin = self.kink_point - self.bandwidth
10061006
fmax = self.kink_point + self.bandwidth
10071007
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
@@ -1027,7 +1027,7 @@ def __init__(
10271027
self.score = self.model.score(X=self.X, y=self.y)
10281028

10291029
# get the model predictions of the observed data
1030-
if self.bandwidth is not None:
1030+
if self.bandwidth is not np.inf:
10311031
xi = np.linspace(fmin, fmax, 200)
10321032
else:
10331033
xi = np.linspace(

causalpy/tests/test_input_validation.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Input validation tests"""
22

3+
import numpy as np
34
import pandas as pd
45
import pytest
56

@@ -223,3 +224,83 @@ def test_iv_treatment_var_is_present():
223224
sample_kwargs=sample_kwargs
224225
),
225226
)
227+
228+
229+
# Regression kink design
230+
231+
232+
def setup_regression_kink_data(kink):
    """Set up data for regression kink design tests.

    Draws 50 values of ``x`` uniformly from [-1, 1) using a fixed seed and
    builds a noisy piecewise outcome ``y`` via ``reg_kink_function``.

    Parameters
    ----------
    kink
        Location of the kink on the running variable ``x``. It is used both to
        generate ``y`` and to define the ``treated`` indicator (``x >= kink``).

    Returns
    -------
    pd.DataFrame
        Data frame with the columns ``x``, ``y`` and ``treated``.
    """
    # define parameters for data generation
    seed = 42
    rng = np.random.default_rng(seed)
    N = 50
    # `kink` comes from the caller; it must not be reassigned here, otherwise
    # the argument would be silently ignored.
    beta = [0, -1, 0, 2, 0]
    sigma = 0.05
    # generate data (x is drawn before the noise so the random stream is stable)
    x = rng.uniform(-1, 1, N)
    y = reg_kink_function(x, beta, kink) + rng.normal(0, sigma, N)
    return pd.DataFrame({"x": x, "y": y, "treated": x >= kink})


def reg_kink_function(x, beta, kink):
    """Utility function for regression kink design. Returns a piecewise function
    evaluated at x with a kink at kink and parameters beta.

    ``beta`` holds five coefficients: intercept, linear and quadratic terms in
    ``x``, followed by linear and quadratic terms in ``(x - kink)`` that are
    only active where ``x >= kink``.
    """
    # indicator that switches on the extra terms at and beyond the kink
    past_kink = x >= kink
    return (
        beta[0]
        + beta[1] * x
        + beta[2] * x**2
        + beta[3] * (x - kink) * past_kink
        + beta[4] * (x - kink) ** 2 * past_kink
    )
257+
258+
259+
def test_rkink_bandwidth_check():
    """Test that we get exceptions when bandwidth parameter is <= 0"""
    kink = 0.5
    for invalid_bandwidth in (0, -1):
        # Build the data and model outside the `pytest.raises` block so that
        # only the experiment constructor can satisfy the expected ValueError;
        # an unrelated ValueError during setup must fail the test.
        df = setup_regression_kink_data(kink)
        model = cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs)
        with pytest.raises(ValueError):
            _ = cp.pymc_experiments.RegressionKink(
                df,
                formula=f"y ~ 1 + x + I((x-{kink})*treated)",
                model=model,
                kink_point=kink,
                bandwidth=invalid_bandwidth,
            )
282+
283+
284+
def test_rkink_epsilon_check():
    """Test that we get exceptions when epsilon parameter is <= 0"""
    kink = 0.5
    for invalid_epsilon in (0, -1):
        # Build the data and model outside the `pytest.raises` block so that
        # only the experiment constructor can satisfy the expected ValueError;
        # an unrelated ValueError during setup must fail the test.
        df = setup_regression_kink_data(kink)
        model = cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs)
        with pytest.raises(ValueError):
            _ = cp.pymc_experiments.RegressionKink(
                df,
                formula=f"y ~ 1 + x + I((x-{kink})*treated)",
                model=model,
                kink_point=kink,
                epsilon=invalid_epsilon,
            )

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,6 @@
77
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
88

99

10-
def reg_kink_function(x, beta, kink):
11-
"""Utility function for regression kink design. Returns a piecewise linear function
12-
evaluated at x with a kink at kink and parameters beta"""
13-
return (
14-
beta[0]
15-
+ beta[1] * x
16-
+ beta[2] * x**2
17-
+ beta[3] * (x - kink) * (x >= kink)
18-
+ beta[4] * (x - kink) ** 2 * (x >= kink)
19-
)
20-
21-
2210
@pytest.mark.integration
2311
def test_did():
2412
"""
@@ -230,6 +218,33 @@ def test_rd_drinking():
230218
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
231219

232220

221+
def setup_regression_kink_data(kink):
    """Set up data for regression kink design tests.

    Draws 50 values of ``x`` uniformly from [-1, 1) using a fixed seed and
    builds a noisy piecewise outcome ``y`` via ``reg_kink_function``.

    Parameters
    ----------
    kink
        Location of the kink on the running variable ``x``. It is used both to
        generate ``y`` and to define the ``treated`` indicator (``x >= kink``).

    Returns
    -------
    pd.DataFrame
        Data frame with the columns ``x``, ``y`` and ``treated``.
    """
    # define parameters for data generation
    seed = 42
    rng = np.random.default_rng(seed)
    N = 50
    # `kink` comes from the caller; it must not be reassigned here, otherwise
    # the argument would be silently ignored.
    beta = [0, -1, 0, 2, 0]
    sigma = 0.05
    # generate data (x is drawn before the noise so the random stream is stable)
    x = rng.uniform(-1, 1, N)
    y = reg_kink_function(x, beta, kink) + rng.normal(0, sigma, N)
    return pd.DataFrame({"x": x, "y": y, "treated": x >= kink})


def reg_kink_function(x, beta, kink):
    """Utility function for regression kink design. Returns a piecewise function
    evaluated at x with a kink at kink and parameters beta.

    ``beta`` holds five coefficients: intercept, linear and quadratic terms in
    ``x``, followed by linear and quadratic terms in ``(x - kink)`` that are
    only active where ``x >= kink``.
    """
    # indicator that switches on the extra terms at and beyond the kink
    past_kink = x >= kink
    return (
        beta[0]
        + beta[1] * x
        + beta[2] * x**2
        + beta[3] * (x - kink) * past_kink
        + beta[4] * (x - kink) ** 2 * past_kink
    )
246+
247+
233248
@pytest.mark.integration
234249
def test_rkink():
235250
"""
@@ -241,18 +256,8 @@ def test_rkink():
241256
3. the correct number of MCMC chains exists in the posterior inference data
242257
4. the correct number of MCMC draws exists in the posterior inference data
243258
"""
244-
# define parameters for data generation
245-
seed = 42
246-
rng = np.random.default_rng(seed)
247-
N = 50
248259
kink = 0.5
249-
beta = [0, -1, 0, 2, 0]
250-
sigma = 0.05
251-
# generate data
252-
x = rng.uniform(-1, 1, N)
253-
y = reg_kink_function(x, beta, kink) + rng.normal(0, sigma, N)
254-
df = pd.DataFrame({"x": x, "y": y, "treated": x >= kink})
255-
# run experiment
260+
df = setup_regression_kink_data(kink)
256261
result = cp.pymc_experiments.RegressionKink(
257262
df,
258263
formula=f"y ~ 1 + x + I((x-{kink})*treated)",
@@ -276,18 +281,8 @@ def test_rkink_bandwidth():
276281
3. the correct number of MCMC chains exists in the posterior inference data
277282
4. the correct number of MCMC draws exists in the posterior inference data
278283
"""
279-
# define parameters for data generation
280-
seed = 42
281-
rng = np.random.default_rng(seed)
282-
N = 50
283284
kink = 0.5
284-
beta = [0, -1, 0, 2, 0]
285-
sigma = 0.05
286-
# generate data
287-
x = rng.uniform(-1, 1, N)
288-
y = reg_kink_function(x, beta, kink) + rng.normal(0, sigma, N)
289-
df = pd.DataFrame({"x": x, "y": y, "treated": x >= kink})
290-
# run experiment
285+
df = setup_regression_kink_data(kink)
291286
result = cp.pymc_experiments.RegressionKink(
292287
df,
293288
formula=f"y ~ 1 + x + I((x-{kink})*treated)",

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)