Skip to content

Commit fdce5b0

Browse files
authored
Merge pull request #489 from williambdean/mock-sample
Use mock sample in tests
2 parents 8b26d42 + c3b6423 commit fdce5b0

File tree

7 files changed

+37
-24
lines changed

7 files changed

+37
-24
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- name: Setup environment
2525
run: pip install -e .[test]
2626
- name: Run doctests
27-
run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/
27+
run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
2828
- name: Run extra tests
2929
run: pytest docs/source/.codespell/test_notebook_to_markdown.py
3030
- name: Run tests

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.PHONY: init lint check_lint test uml html cleandocs
1+
.PHONY: init lint check_lint test uml html cleandocs doctest
22

33
init:
44
python -m pip install -e . --no-deps
@@ -13,7 +13,7 @@ check_lint:
1313
interrogate .
1414

1515
doctest:
16-
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/
16+
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
1717

1818
test:
1919
pytest

causalpy/experiments/prepostnegd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class PrePostNEGD(BaseExperiment):
7272
... }
7373
... ),
7474
... )
75-
>>> result.summary(round_to=1) # doctest: +NUMBER
75+
>>> result.summary(round_to=1) # doctest: +SKIP
7676
==================Pretest/posttest Nonequivalent Group Design===================
7777
Formula: post ~ 1 + C(group) + pre
7878
<BLANKLINE>

causalpy/tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,24 @@
2020

2121
import numpy as np
2222
import pytest
23+
from pymc.testing import mock_sample, mock_sample_setup_and_teardown
2324

2425

2526
@pytest.fixture(scope="session")
2627
def rng() -> np.random.Generator:
2728
"""Random number generator that can persist through a pytest session"""
2829
seed: int = sum(map(ord, "causalpy"))
2930
return np.random.default_rng(seed=seed)
31+
32+
33+
mock_pymc_sample = pytest.fixture(mock_sample_setup_and_teardown, scope="session")
34+
35+
36+
@pytest.fixture(autouse=True)
37+
def mock_sample_for_doctest(request):
38+
if not request.config.getoption("--doctest-modules", default=False):
39+
return
40+
41+
import pymc as pm
42+
43+
pm.sample = mock_sample

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@pytest.mark.integration
26-
def test_did():
26+
def test_did(mock_pymc_sample):
2727
"""
2828
Test Difference in Differences (DID) PyMC experiment.
2929
@@ -57,7 +57,7 @@ def test_did():
5757

5858

5959
@pytest.mark.integration
60-
def test_did_banks_simple():
60+
def test_did_banks_simple(mock_pymc_sample):
6161
"""
6262
Test simple Differences In Differences Experiment on the 'banks' data set.
6363
@@ -113,7 +113,7 @@ def test_did_banks_simple():
113113

114114

115115
@pytest.mark.integration
116-
def test_did_banks_multi():
116+
def test_did_banks_multi(mock_pymc_sample):
117117
"""
118118
Test multiple regression Differences In Differences Experiment on the 'banks'
119119
data set.
@@ -168,7 +168,7 @@ def test_did_banks_multi():
168168

169169

170170
@pytest.mark.integration
171-
def test_rd():
171+
def test_rd(mock_pymc_sample):
172172
"""
173173
Test Regression Discontinuity experiment.
174174
@@ -199,7 +199,7 @@ def test_rd():
199199

200200

201201
@pytest.mark.integration
202-
def test_rd_bandwidth():
202+
def test_rd_bandwidth(mock_pymc_sample):
203203
"""
204204
Test Regression Discontinuity experiment with bandwidth parameter.
205205
@@ -229,7 +229,7 @@ def test_rd_bandwidth():
229229

230230

231231
@pytest.mark.integration
232-
def test_rd_drinking():
232+
def test_rd_drinking(mock_pymc_sample):
233233
"""
234234
Test Regression Discontinuity experiment on drinking age data.
235235
@@ -289,7 +289,7 @@ def reg_kink_function(x, beta, kink):
289289

290290

291291
@pytest.mark.integration
292-
def test_rkink():
292+
def test_rkink(mock_pymc_sample):
293293
"""
294294
Test Regression Kink design.
295295
@@ -320,7 +320,7 @@ def test_rkink():
320320

321321

322322
@pytest.mark.integration
323-
def test_rkink_bandwidth():
323+
def test_rkink_bandwidth(mock_pymc_sample):
324324
"""
325325
Test Regression Kink experiment with bandwidth parameter.
326326
@@ -350,7 +350,7 @@ def test_rkink_bandwidth():
350350

351351

352352
@pytest.mark.integration
353-
def test_its():
353+
def test_its(mock_pymc_sample):
354354
"""
355355
Test Interrupted Time-Series experiment.
356356
@@ -403,7 +403,7 @@ def test_its():
403403

404404

405405
@pytest.mark.integration
406-
def test_its_covid():
406+
def test_its_covid(mock_pymc_sample):
407407
"""
408408
Test Interrupted Time-Series experiment on COVID data.
409409
@@ -457,7 +457,7 @@ def test_its_covid():
457457

458458

459459
@pytest.mark.integration
460-
def test_sc():
460+
def test_sc(mock_pymc_sample):
461461
"""
462462
Test Synthetic Control experiment.
463463
@@ -516,7 +516,7 @@ def test_sc():
516516

517517

518518
@pytest.mark.integration
519-
def test_sc_brexit():
519+
def test_sc_brexit(mock_pymc_sample):
520520
"""
521521
Test Synthetic Control experiment on Brexit data.
522522
@@ -579,7 +579,7 @@ def test_sc_brexit():
579579

580580

581581
@pytest.mark.integration
582-
def test_ancova():
582+
def test_ancova(mock_pymc_sample):
583583
"""
584584
Test Pre-PostNEGD experiment on anova1 data.
585585
@@ -611,7 +611,7 @@ def test_ancova():
611611

612612

613613
@pytest.mark.integration
614-
def test_geolift1():
614+
def test_geolift1(mock_pymc_sample):
615615
"""
616616
Test Synthetic Control experiment on geo lift data.
617617
@@ -648,7 +648,7 @@ def test_geolift1():
648648

649649

650650
@pytest.mark.integration
651-
def test_iv_reg():
651+
def test_iv_reg(mock_pymc_sample):
652652
df = cp.load_data("risk")
653653
instruments_formula = "risk ~ 1 + logmort0"
654654
formula = "loggdp ~ 1 + risk"
@@ -676,7 +676,7 @@ def test_iv_reg():
676676

677677

678678
@pytest.mark.integration
679-
def test_inverse_prop():
679+
def test_inverse_prop(mock_pymc_sample):
680680
"""Test the InversePropensityWeighting class."""
681681
df = cp.load_data("nhefs")
682682
sample_kwargs = {

causalpy/tests/test_pymc_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_fit_build_not_implemented(self):
9393
argvalues=[None, {"a": 1}],
9494
ids=["None-coords", "dict-coords"],
9595
)
96-
def test_fit_predict(self, coords, rng) -> None:
96+
def test_fit_predict(self, coords, rng, mock_pymc_sample) -> None:
9797
"""
9898
Test fit and predict methods on MyToyModel.
9999
@@ -122,7 +122,7 @@ def test_fit_predict(self, coords, rng) -> None:
122122
assert isinstance(predictions, az.InferenceData)
123123

124124

125-
def test_idata_property():
125+
def test_idata_property(mock_pymc_sample):
126126
"""Test that we can access the idata property of the model"""
127127
df = cp.load_data("did")
128128
result = cp.DifferenceInDifferences(
@@ -140,7 +140,7 @@ def test_idata_property():
140140

141141

142142
@pytest.mark.parametrize("seed", seeds)
143-
def test_result_reproducibility(seed):
143+
def test_result_reproducibility(seed, mock_pymc_sample):
144144
"""Test that we can reproduce the results from the model. We could in theory test
145145
this with all the model and experiment types, but what is being targeted is
146146
the PyMCModel.fit method, so we should be safe testing with just one model. Here

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ addopts = [
8484
"--strict-config",
8585
"--cov=causalpy",
8686
"--cov-report=term-missing",
87-
"--doctest-modules",
8887
]
8988
testpaths = "causalpy/tests"
9089
markers = [

0 commit comments

Comments
 (0)