Skip to content

Commit 3207e03

Browse files
committed
Merge branch 'main' into summary
2 parents 22f5008 + 4af4af6 commit 3207e03

19 files changed

+6683
-29
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
name: Read the Docs Pull Request Preview
2+
on:
3+
pull_request_target:
4+
types:
5+
- opened
6+
7+
permissions:
8+
pull-requests: write
9+
10+
jobs:
11+
documentation-links:
12+
runs-on: ubuntu-latest
13+
steps:
14+
- uses: readthedocs/actions/preview@v1
15+
with:
16+
project-slug: "causalpy"

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ _build
77
build/
88
dist/
99
docs/_build/
10+
docs/build/
11+
docs/jupyter_execute/
1012
*.vscode
1113
.coverage
1214
*.jupyterlab-workspace

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ repos:
2222
exclude_types: [svg]
2323
- id: check-yaml
2424
- id: check-added-large-files
25+
exclude: &exclude_pattern 'iv_weak_instruments.ipynb'
2526
args: ["--maxkb=1500"]
2627
- repo: https://github.com/astral-sh/ruff-pre-commit
27-
rev: v0.4.8
28+
rev: v0.4.9
2829
hooks:
2930
# Run the linter
3031
- id: ruff

.readthedocs.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ version: 2
77

88
# Set the version of Python and other tools you might need
99
build:
10-
os: ubuntu-20.04
10+
os: ubuntu-lts-latest
1111
tools:
12-
python: "3.10"
12+
python: "3.11"
1313
# You can also specify other tool versions:
1414
# nodejs: "16"
1515
# rust: "1.55"

Makefile

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ lint:
1010
check_lint:
1111
ruff check .
1212
ruff format --diff --check .
13-
nbqa black --check .
14-
nbqa ruff .
1513
interrogate .
1614

1715
doctest:

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pip install git+https://github.com/pymc-labs/CausalPy.git
3434

3535
```python
3636
import causalpy as cp
37-
37+
import matplotlib.pyplot as plt
3838

3939
# Import and process data
4040
df = (
@@ -57,6 +57,8 @@ fig, ax = result.plot();
5757

5858
# Get a results summary
5959
result.summary()
60+
61+
plt.show()
6062
```
6163

6264
## Roadmap

causalpy/data/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"geolift1": {"filename": "geolift1.csv"},
3636
"risk": {"filename": "AJR2001.csv"},
3737
"nhefs": {"filename": "nhefs.csv"},
38+
"schoolReturns": {"filename": "schoolingReturns.csv"},
3839
}
3940

4041

causalpy/data/schoolingReturns.csv

Lines changed: 3011 additions & 0 deletions
Large diffs are not rendered by default.

causalpy/pymc_experiments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1450,7 +1450,7 @@ def __init__(
14501450
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
14511451
"sigmas": [1, 1],
14521452
"eta": 2,
1453-
"lkj_sd": 2,
1453+
"lkj_sd": 1,
14541454
}
14551455
self.priors = priors
14561456
self.model.fit(

causalpy/pymc_models.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,8 @@ class InstrumentalVariableRegression(ModelBuilder):
303303
... "mus": [[-2,4], [0.5, 3]],
304304
... "sigmas": [1, 1],
305305
... "eta": 2,
306-
... "lkj_sd": 2,
307-
... })
306+
... "lkj_sd": 1,
307+
... }, None)
308308
Inference data...
309309
"""
310310

@@ -340,7 +340,7 @@ def build_model(self, X, Z, y, t, coords, priors):
340340
sigma=priors["sigmas"][1],
341341
dims="covariates",
342342
)
343-
sd_dist = pm.HalfCauchy.dist(beta=priors["lkj_sd"], shape=2)
343+
sd_dist = pm.Exponential.dist(priors["lkj_sd"], shape=2)
344344
chol, corr, sigmas = pm.LKJCholeskyCov(
345345
name="chol_cov",
346346
eta=priors["eta"],
@@ -366,24 +366,52 @@ def build_model(self, X, Z, y, t, coords, priors):
366366
shape=(X.shape[0], 2),
367367
)
368368

369-
def fit(self, X, Z, y, t, coords, priors):
370-
"""Draw samples from posterior, prior predictive, and posterior predictive
371-
distributions.
369+
def sample_predictive_distribution(self, ppc_sampler="jax"):
370+
"""Function to sample the Multivariate Normal posterior predictive
371+
Likelihood term in the IV class. This can be slow without
372+
using the JAX sampler compilation method. If using the
373+
JAX sampler it will sample only the posterior predictive distribution.
374+
If using the PYMC sampler if will sample both the prior
375+
and posterior predictive distributions."""
376+
random_seed = self.sample_kwargs.get("random_seed", None)
377+
378+
if ppc_sampler == "jax":
379+
with self:
380+
self.idata.extend(
381+
pm.sample_posterior_predictive(
382+
self.idata,
383+
random_seed=random_seed,
384+
compile_kwargs={"mode": "JAX"},
385+
)
386+
)
387+
elif ppc_sampler == "pymc":
388+
with self:
389+
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
390+
self.idata.extend(
391+
pm.sample_posterior_predictive(
392+
self.idata,
393+
random_seed=random_seed,
394+
)
395+
)
396+
397+
def fit(self, X, Z, y, t, coords, priors, ppc_sampler=None):
398+
"""Draw samples from posterior distribution and potentially
399+
from the prior and posterior predictive distributions. The
400+
fit call can take values for the
401+
ppc_sampler = ['jax', 'pymc', None]
402+
We default to None, so the user can determine if they wish
403+
to spend time sampling the posterior predictive distribution
404+
independently.
372405
"""
373406

374407
# Ensure random_seed is used in sample_prior_predictive() and
375408
# sample_posterior_predictive() if provided in sample_kwargs.
376-
random_seed = self.sample_kwargs.get("random_seed", None)
409+
# Use JAX for ppc sampling of multivariate likelihood
377410

378411
self.build_model(X, Z, y, t, coords, priors)
379412
with self:
380413
self.idata = pm.sample(**self.sample_kwargs)
381-
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
382-
self.idata.extend(
383-
pm.sample_posterior_predictive(
384-
self.idata, progressbar=False, random_seed=random_seed
385-
)
386-
)
414+
self.sample_predictive_distribution(ppc_sampler=ppc_sampler)
387415
return self.idata
388416

389417

0 commit comments

Comments
 (0)