Skip to content

Commit 92b7218

Browse files
committed
Merge branch 'main' into fix-docs
2 parents c77d917 + e55f23b commit 92b7218

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+8746
-8223
lines changed

CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ We recommend that your contribution complies with the following guidelines befor
128128
make doctest
129129
```
130130

131-
- Doctest can also be run directly via pytest, which can be helpful to run only specific tests during development. The following commands run all doctests, only doctests in the pymc_models module, and only the doctests for the `ModelBuilder` class in pymc_models:
131+
- Doctest can also be run directly via pytest, which can be helpful to run only specific tests during development. The following commands run all doctests, only doctests in the pymc_models module, and only the doctests for the `PyMCModel` class in pymc_models:
132132

133133
```bash
134134
pytest --doctest-modules causalpy/
135135
pytest --doctest-modules causalpy/pymc_models.py
136-
pytest --doctest-modules causalpy/pmyc_models.py::causalpy.pymc_models.ModelBuilder
136+
pytest --doctest-modules causalpy/pmyc_models.py::causalpy.pymc_models.PyMCModel
137137
```
138138

139139
- To indicate a work in progress please mark the PR as `draft`. Drafts may be useful to (1) indicate you are working on something to avoid duplicated work, (2) request broad review of functionality or API, or (3) seek collaborators.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ df = (
4444
)
4545

4646
# Run the analysis
47-
result = cp.pymc_experiments.RegressionDiscontinuity(
47+
result = cp.RegressionDiscontinuity(
4848
df,
4949
formula="all ~ 1 + age + treated",
5050
running_variable_name="age",

causalpy/__init__.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,38 @@
1313
# limitations under the License.
1414
import arviz as az
1515

16-
from causalpy import pymc_experiments, pymc_models, skl_experiments, skl_models
16+
import causalpy.pymc_experiments as pymc_experiments # to be deprecated
17+
import causalpy.pymc_models as pymc_models
18+
import causalpy.skl_experiments as skl_experiments # to be deprecated
19+
import causalpy.skl_models as skl_models
20+
from causalpy.skl_models import create_causalpy_compatible_class
1721
from causalpy.version import __version__
1822

1923
from .data import load_data
24+
from .experiments.diff_in_diff import DifferenceInDifferences
25+
from .experiments.instrumental_variable import InstrumentalVariable
26+
from .experiments.inverse_propensity_weighting import InversePropensityWeighting
27+
from .experiments.prepostfit import InterruptedTimeSeries, SyntheticControl
28+
from .experiments.prepostnegd import PrePostNEGD
29+
from .experiments.regression_discontinuity import RegressionDiscontinuity
30+
from .experiments.regression_kink import RegressionKink
2031

2132
az.style.use("arviz-darkgrid")
2233

2334
__all__ = [
24-
"pymc_experiments",
35+
"__version__",
36+
"DifferenceInDifferences",
37+
"create_causalpy_compatible_class",
38+
"InstrumentalVariable",
39+
"InterruptedTimeSeries",
40+
"InversePropensityWeighting",
41+
"load_data",
42+
"PrePostNEGD",
43+
"pymc_experiments", # to be deprecated
2544
"pymc_models",
26-
"skl_experiments",
45+
"RegressionDiscontinuity",
46+
"RegressionKink",
47+
"skl_experiments", # to be deprecated
2748
"skl_models",
28-
"load_data",
29-
"__version__",
49+
"SyntheticControl",
3050
]

causalpy/data_validation.py

Lines changed: 0 additions & 174 deletions
This file was deleted.

causalpy/experiments/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

causalpy/experiments/base.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Base class for quasi experimental designs.
16+
"""
17+
18+
from abc import abstractmethod
19+
20+
from sklearn.base import RegressorMixin
21+
22+
from causalpy.pymc_models import PyMCModel
23+
from causalpy.skl_models import create_causalpy_compatible_class
24+
25+
26+
class BaseExperiment:
27+
"""Base class for quasi experimental designs."""
28+
29+
supports_bayes: bool
30+
supports_ols: bool
31+
32+
def __init__(self, model=None):
33+
# Ensure we've made any provided Scikit Learn model (as identified as being type
34+
# RegressorMixin) compatible with CausalPy by appending our custom methods.
35+
if isinstance(model, RegressorMixin):
36+
model = create_causalpy_compatible_class(model)
37+
38+
if model is not None:
39+
self.model = model
40+
41+
if isinstance(self.model, PyMCModel) and not self.supports_bayes:
42+
raise ValueError("Bayesian models not supported.")
43+
44+
if isinstance(self.model, RegressorMixin) and not self.supports_ols:
45+
raise ValueError("OLS models not supported.")
46+
47+
if self.model is None:
48+
raise ValueError("model not set or passed.")
49+
50+
@property
51+
def idata(self):
52+
"""Return the InferenceData object of the model. Only relevant for PyMC models."""
53+
return self.model.idata
54+
55+
def print_coefficients(self, round_to=None):
56+
"""Ask the model to print its coefficients."""
57+
self.model.print_coefficients(self.labels, round_to)
58+
59+
def plot(self, *args, **kwargs) -> tuple:
60+
"""Plot the model.
61+
62+
Internally, this function dispatches to either `bayesian_plot` or `ols_plot`
63+
depending on the model type.
64+
"""
65+
if isinstance(self.model, PyMCModel):
66+
return self.bayesian_plot(*args, **kwargs)
67+
elif isinstance(self.model, RegressorMixin):
68+
return self.ols_plot(*args, **kwargs)
69+
else:
70+
raise ValueError("Unsupported model type")
71+
72+
@abstractmethod
73+
def bayesian_plot(self, *args, **kwargs):
74+
"""Abstract method for plotting the model."""
75+
raise NotImplementedError("bayesian_plot method not yet implemented")
76+
77+
@abstractmethod
78+
def ols_plot(self, *args, **kwargs):
79+
"""Abstract method for plotting the model."""
80+
raise NotImplementedError("ols_plot method not yet implemented")

0 commit comments

Comments
 (0)