Skip to content

Commit e49b005

Browse files
committed
Merge branch 'main' into inv_propensity_model
Signed-off-by: Nathaniel <NathanielF@users.noreply.github.com>
2 parents a863252 + 3dc2ffe commit e49b005

28 files changed

+858
-23
lines changed

.pre-commit-config.yaml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@ ci:
44
# See https://pre-commit.com for more information
55
# See https://pre-commit.com/hooks.html for more hooks
66
repos:
7+
- repo: https://github.com/lucianopaz/head_of_apache
8+
rev: "0.0.3"
9+
hooks:
10+
- id: head_of_apache
11+
args:
12+
- --author=The PyMC Labs Developers
13+
- --exclude=docs/
14+
- --exclude=scripts/
715
- repo: https://github.com/pre-commit/pre-commit-hooks
816
rev: v4.6.0
917
hooks:
@@ -16,7 +24,7 @@ repos:
1624
- id: check-added-large-files
1725
args: ["--maxkb=1500"]
1826
- repo: https://github.com/astral-sh/ruff-pre-commit
19-
rev: v0.4.2
27+
rev: v0.4.3
2028
hooks:
2129
- id: ruff
2230
args: ["--fix", "--output-format=full"]

causalpy/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
import arviz as az
215

316
from causalpy import pymc_experiments, pymc_models, skl_experiments, skl_models

causalpy/custom_exceptions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""
215
Custom Exceptions for CausalPy.
316
"""

causalpy/data/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""Code for loading datasets."""
215

316
from .datasets import load_data

causalpy/data/datasets.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""
215
Functions to load example datasets
316
"""

causalpy/data/simulate_data.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""
215
Functions that generate data sets used in examples
316
"""

causalpy/data_validation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
import warnings # noqa: I001
215

316
import pandas as pd

causalpy/plot_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""
215
Plotting utility functions.
316
"""

causalpy/pymc_experiments.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""
215
Experiment routines for PyMC models.
316
@@ -89,6 +102,11 @@ def print_coefficients(self, round_to=None) -> None:
89102
... "progressbar": False
90103
... }),
91104
... )
105+
<BLANKLINE>
106+
<BLANKLINE>
107+
<BLANKLINE>
108+
<BLANKLINE>
109+
<BLANKLINE>
92110
>>> result.print_coefficients(round_to=1) # doctest: +NUMBER
93111
Model coefficients:
94112
Intercept 1, 94% HDI [1, 1]
@@ -147,6 +165,11 @@ class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
147165
... }
148166
... ),
149167
... )
168+
<BLANKLINE>
169+
<BLANKLINE>
170+
<BLANKLINE>
171+
<BLANKLINE>
172+
<BLANKLINE>
150173
>>> result.summary(round_to=1) # doctest: +NUMBER
151174
==================================Pre-Post Fit==================================
152175
Formula: actual ~ 0 + a + g
@@ -373,6 +396,11 @@ class InterruptedTimeSeries(PrePostFit):
373396
... }
374397
... )
375398
... )
399+
<BLANKLINE>
400+
<BLANKLINE>
401+
<BLANKLINE>
402+
<BLANKLINE>
403+
<BLANKLINE>
376404
"""
377405

378406
expt_type = "Interrupted Time Series"
@@ -408,6 +436,11 @@ class SyntheticControl(PrePostFit):
408436
... }
409437
... ),
410438
... )
439+
<BLANKLINE>
440+
<BLANKLINE>
441+
<BLANKLINE>
442+
<BLANKLINE>
443+
<BLANKLINE>
411444
"""
412445

413446
expt_type = "Synthetic Control"
@@ -464,6 +497,11 @@ class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator):
464497
... }
465498
... )
466499
... )
500+
<BLANKLINE>
501+
<BLANKLINE>
502+
<BLANKLINE>
503+
<BLANKLINE>
504+
<BLANKLINE>
467505
"""
468506

469507
def __init__(
@@ -755,6 +793,11 @@ class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
755793
... ),
756794
... treatment_threshold=0.5,
757795
... )
796+
<BLANKLINE>
797+
<BLANKLINE>
798+
<BLANKLINE>
799+
<BLANKLINE>
800+
<BLANKLINE>
758801
"""
759802

760803
def __init__(
@@ -1164,6 +1207,10 @@ class PrePostNEGD(ExperimentalDesign, PrePostNEGDDataValidator):
11641207
... }
11651208
... )
11661209
... )
1210+
<BLANKLINE>
1211+
<BLANKLINE>
1212+
<BLANKLINE>
1213+
<BLANKLINE>
11671214
>>> result.summary(round_to=1) # doctest: +NUMBER
11681215
==================Pretest/posttest Nonequivalent Group Design===================
11691216
Formula: post ~ 1 + C(group) + pre
@@ -1394,6 +1441,8 @@ class InstrumentalVariable(ExperimentalDesign, IVDataValidator):
13941441
... formula=formula,
13951442
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
13961443
... )
1444+
<BLANKLINE>
1445+
<BLANKLINE>
13971446
"""
13981447

13991448
def __init__(

causalpy/pymc_models.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""
215
Defines generic PyMC ModelBuilder class and subclasses for
316
@@ -41,8 +54,8 @@ class ModelBuilder(pm.Model):
4154
>>> class MyToyModel(ModelBuilder):
4255
... def build_model(self, X, y, coords):
4356
... with self:
44-
... X_ = pm.MutableData(name="X", value=X)
45-
... y_ = pm.MutableData(name="y", value=y)
57+
... X_ = pm.Data(name="X", value=X)
58+
... y_ = pm.Data(name="y", value=y)
4659
... beta = pm.Normal("beta", mu=0, sigma=1, shape=X_.shape[1])
4760
... sigma = pm.HalfNormal("sigma", sigma=1)
4861
... mu = pm.Deterministic("mu", pm.math.dot(X_, beta))
@@ -59,13 +72,17 @@ class ModelBuilder(pm.Model):
5972
... }
6073
... )
6174
>>> model.fit(X, y)
62-
Inference...
75+
<BLANKLINE>
76+
<BLANKLINE>
77+
Inference data...
6378
>>> X_new = rng.normal(loc=0, scale=1, size=(20,2))
6479
>>> model.predict(X_new)
65-
Inference...
66-
>>> model.score(X, y) # doctest: +NUMBER
67-
r2 0.3
68-
r2_std 0.0
80+
<BLANKLINE>
81+
Inference data...
82+
>>> model.score(X, y)
83+
<BLANKLINE>
84+
r2 0.390344
85+
r2_std 0.081135
6986
dtype: float64
7087
"""
7188

@@ -99,10 +116,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
99116

100117
# Ensure random_seed is used in sample_prior_predictive() and
101118
# sample_posterior_predictive() if provided in sample_kwargs.
102-
if "random_seed" in self.sample_kwargs:
103-
random_seed = self.sample_kwargs["random_seed"]
104-
else:
105-
random_seed = None
119+
random_seed = self.sample_kwargs.get("random_seed", None)
106120

107121
self.build_model(X, y, coords)
108122
with self:
@@ -124,10 +138,17 @@ def predict(self, X):
124138
125139
"""
126140

141+
# Ensure random_seed is used in sample_prior_predictive() and
142+
# sample_posterior_predictive() if provided in sample_kwargs.
143+
random_seed = self.sample_kwargs.get("random_seed", None)
144+
127145
self._data_setter(X)
128146
with self: # sample with new input data
129147
post_pred = pm.sample_posterior_predictive(
130-
self.idata, var_names=["y_hat", "mu"], progressbar=False
148+
self.idata,
149+
var_names=["y_hat", "mu"],
150+
progressbar=False,
151+
random_seed=random_seed,
131152
)
132153
return post_pred
133154

@@ -180,7 +201,9 @@ class WeightedSumFitter(ModelBuilder):
180201
>>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
181202
>>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
182203
>>> wsf.fit(X,y)
183-
Inference ...
204+
<BLANKLINE>
205+
<BLANKLINE>
206+
Inference data...
184207
""" # noqa: W605
185208

186209
def build_model(self, X, y, coords):
@@ -190,8 +213,8 @@ def build_model(self, X, y, coords):
190213
with self:
191214
self.add_coords(coords)
192215
n_predictors = X.shape[1]
193-
X = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
194-
y = pm.MutableData("y", y[:, 0], dims="obs_ind")
216+
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
217+
y = pm.Data("y", y[:, 0], dims="obs_ind")
195218
# TODO: There we should allow user-specified priors here
196219
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
197220
# beta = pm.Dirichlet(
@@ -236,7 +259,9 @@ class LinearRegression(ModelBuilder):
236259
... 'obs_indx': np.arange(rd.shape[0])
237260
... },
238261
... )
239-
Inference...
262+
<BLANKLINE>
263+
<BLANKLINE>
264+
Inference data...
240265
""" # noqa: W605
241266

242267
def build_model(self, X, y, coords):
@@ -245,8 +270,8 @@ def build_model(self, X, y, coords):
245270
"""
246271
with self:
247272
self.add_coords(coords)
248-
X = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
249-
y = pm.MutableData("y", y[:, 0], dims="obs_ind")
273+
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
274+
y = pm.Data("y", y[:, 0], dims="obs_ind")
250275
beta = pm.Normal("beta", 0, 50, dims="coeffs")
251276
sigma = pm.HalfNormal("sigma", 1)
252277
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
@@ -288,6 +313,8 @@ class InstrumentalVariableRegression(ModelBuilder):
288313
... "eta": 2,
289314
... "lkj_sd": 2,
290315
... })
316+
<BLANKLINE>
317+
<BLANKLINE>
291318
Inference data...
292319
"""
293320

0 commit comments

Comments
 (0)