Skip to content

Commit b4722b6

Browse files
committed
Add jax as dependency re-configure IV class fit call
1 parent b92b606 commit b4722b6

File tree

4 files changed

+42
-17
lines changed

4 files changed

+42
-17
lines changed

causalpy/pymc_experiments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1453,7 +1453,7 @@ def __init__(
14531453
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
14541454
"sigmas": [1, 1],
14551455
"eta": 2,
1456-
"lkj_sd": 2,
1456+
"lkj_sd": 1,
14571457
}
14581458
self.priors = priors
14591459
self.model.fit(

causalpy/pymc_models.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def build_model(self, X, Z, y, t, coords, priors):
340340
sigma=priors["sigmas"][1],
341341
dims="covariates",
342342
)
343-
sd_dist = pm.HalfCauchy.dist(beta=priors["lkj_sd"], shape=2)
343+
sd_dist = pm.Exponential.dist(priors["lkj_sd"], shape=2)
344344
chol, corr, sigmas = pm.LKJCholeskyCov(
345345
name="chol_cov",
346346
eta=priors["eta"],
@@ -366,27 +366,51 @@ def build_model(self, X, Z, y, t, coords, priors):
366366
shape=(X.shape[0], 2),
367367
)
368368

369-
def fit(self, X, Z, y, t, coords, priors, sample_ppc=True):
370-
"""Draw samples from posterior, prior predictive, and posterior predictive
371-
distributions.
372-
"""
373-
374-
# Ensure random_seed is used in sample_prior_predictive() and
375-
# sample_posterior_predictive() if provided in sample_kwargs.
376-
# Use JAX for ppc sampling of multivariate likelihood
369+
def sample_predictive_distribution(self, ppc_sampler="jax"):
370+
"""Function to sample the Multivariate Normal posterior predictive
371+
Likelihood term in the IV class. This can be slow without
372+
using the JAX sampler compilation method. If using the
373+
JAX sampler it will sample only the posterior predictive distribution.
374+
If using the PYMC sampler if will sample both the prior
375+
and posterior predictive distributions."""
377376
random_seed = self.sample_kwargs.get("random_seed", None)
378377

379-
self.build_model(X, Z, y, t, coords, priors)
380-
with self:
381-
self.idata = pm.sample(**self.sample_kwargs)
382-
if sample_ppc:
378+
if ppc_sampler == "jax":
379+
with self:
383380
self.idata.extend(
384381
pm.sample_posterior_predictive(
385382
self.idata,
386383
random_seed=random_seed,
387384
compile_kwargs={"mode": "JAX"},
388385
)
389386
)
387+
elif ppc_sampler == "pymc":
388+
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
389+
self.idata.extend(
390+
pm.sample_posterior_predictive(
391+
self.idata,
392+
random_seed=random_seed,
393+
)
394+
)
395+
396+
def fit(self, X, Z, y, t, coords, priors, ppc_sampler=None):
397+
"""Draw samples from posterior distribution and potentially
398+
from the prior and posterior predictive distributions. The
399+
fit call can take values for the
400+
ppc_sampler = ['jax', 'pymc', None]
401+
We default to None, so the user can determine if they wish
402+
to spend time sampling the posterior predictive distribution
403+
independently.
404+
"""
405+
406+
# Ensure random_seed is used in sample_prior_predictive() and
407+
# sample_posterior_predictive() if provided in sample_kwargs.
408+
# Use JAX for ppc sampling of multivariate likelihood
409+
410+
self.build_model(X, Z, y, t, coords, priors)
411+
with self:
412+
self.idata = pm.sample(**self.sample_kwargs)
413+
self.sample_predictive_distribution(ppc_sampler=ppc_sampler)
390414
return self.idata
391415

392416

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ dependencies:
1515
- seaborn>=0.11.2
1616
- statsmodels
1717
- xarray>=v2022.11.0
18+
- jax

0 commit comments

Comments
 (0)