@@ -110,7 +110,7 @@ def _data_setter(self, X) -> None:
110
110
pm .set_data ({"X" : X })
111
111
112
112
def fit (self , X , y , coords : Optional [Dict [str , Any ]] = None ) -> None :
113
- """Draw samples fromposterior , prior predictive, and posterior predictive
113
+ """Draw samples from posterior , prior predictive, and posterior predictive
114
114
distributions, placing them in the model's idata attribute.
115
115
"""
116
116
@@ -380,12 +380,19 @@ def fit(self, X, Z, y, t, coords, priors):
380
380
"""Draw samples from posterior, prior predictive, and posterior predictive
381
381
distributions.
382
382
"""
383
+
384
+ # Ensure random_seed is used in sample_prior_predictive() and
385
+ # sample_posterior_predictive() if provided in sample_kwargs.
386
+ random_seed = self .sample_kwargs .get ("random_seed" , None )
387
+
383
388
self .build_model (X , Z , y , t , coords , priors )
384
389
with self :
385
390
self .idata = pm .sample (** self .sample_kwargs )
386
- self .idata .extend (pm .sample_prior_predictive ())
391
+ self .idata .extend (pm .sample_prior_predictive (random_seed = random_seed ))
387
392
self .idata .extend (
388
- pm .sample_posterior_predictive (self .idata , progressbar = False )
393
+ pm .sample_posterior_predictive (
394
+ self .idata , progressbar = False , random_seed = random_seed
395
+ )
389
396
)
390
397
return self .idata
391
398
0 commit comments