Skip to content

Commit 0ffe348

Browse files
authored
Merge pull request #335 from anevolbap/issue_329-IV-fit-method
Fix #329: Add `random_seed` to the `fit` method.
2 parents 7d21e99 + f06e3d7 commit 0ffe348

File tree

2 files changed

+10
-61
lines changed

2 files changed

+10
-61
lines changed

causalpy/pymc_models.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _data_setter(self, X) -> None:
110110
pm.set_data({"X": X})
111111

112112
def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
113-
"""Draw samples fromposterior, prior predictive, and posterior predictive
113+
"""Draw samples from posterior, prior predictive, and posterior predictive
114114
distributions, placing them in the model's idata attribute.
115115
"""
116116

@@ -380,12 +380,19 @@ def fit(self, X, Z, y, t, coords, priors):
380380
"""Draw samples from posterior, prior predictive, and posterior predictive
381381
distributions.
382382
"""
383+
384+
# Ensure random_seed is used in sample_prior_predictive() and
385+
# sample_posterior_predictive() if provided in sample_kwargs.
386+
random_seed = self.sample_kwargs.get("random_seed", None)
387+
383388
self.build_model(X, Z, y, t, coords, priors)
384389
with self:
385390
self.idata = pm.sample(**self.sample_kwargs)
386-
self.idata.extend(pm.sample_prior_predictive())
391+
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
387392
self.idata.extend(
388-
pm.sample_posterior_predictive(self.idata, progressbar=False)
393+
pm.sample_posterior_predictive(
394+
self.idata, progressbar=False, random_seed=random_seed
395+
)
389396
)
390397
return self.idata
391398

docs/source/_static/interrogate_badge.svg

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)