@@ -340,7 +340,7 @@ def build_model(self, X, Z, y, t, coords, priors):
340
340
sigma = priors ["sigmas" ][1 ],
341
341
dims = "covariates" ,
342
342
)
343
- sd_dist = pm .HalfCauchy .dist (beta = priors ["lkj_sd" ], shape = 2 )
343
+ sd_dist = pm .Exponential .dist (priors ["lkj_sd" ], shape = 2 )
344
344
chol , corr , sigmas = pm .LKJCholeskyCov (
345
345
name = "chol_cov" ,
346
346
eta = priors ["eta" ],
@@ -366,27 +366,51 @@ def build_model(self, X, Z, y, t, coords, priors):
366
366
shape = (X .shape [0 ], 2 ),
367
367
)
368
368
369
- def fit (self , X , Z , y , t , coords , priors , sample_ppc = True ):
370
- """Draw samples from posterior, prior predictive, and posterior predictive
371
- distributions.
372
- """
373
-
374
- # Ensure random_seed is used in sample_prior_predictive() and
375
- # sample_posterior_predictive() if provided in sample_kwargs.
376
- # Use JAX for ppc sampling of multivariate likelihood
369
+ def sample_predictive_distribution (self , ppc_sampler = "jax" ):
370
+ """Function to sample the Multivariate Normal posterior predictive
371
+ Likelihood term in the IV class. This can be slow without
372
+ using the JAX sampler compilation method. If using the
373
+ JAX sampler it will sample only the posterior predictive distribution.
374
+ If using the PYMC sampler if will sample both the prior
375
+ and posterior predictive distributions."""
377
376
random_seed = self .sample_kwargs .get ("random_seed" , None )
378
377
379
- self .build_model (X , Z , y , t , coords , priors )
380
- with self :
381
- self .idata = pm .sample (** self .sample_kwargs )
382
- if sample_ppc :
378
+ if ppc_sampler == "jax" :
379
+ with self :
383
380
self .idata .extend (
384
381
pm .sample_posterior_predictive (
385
382
self .idata ,
386
383
random_seed = random_seed ,
387
384
compile_kwargs = {"mode" : "JAX" },
388
385
)
389
386
)
387
+ elif ppc_sampler == "pymc" :
388
+ self .idata .extend (pm .sample_prior_predictive (random_seed = random_seed ))
389
+ self .idata .extend (
390
+ pm .sample_posterior_predictive (
391
+ self .idata ,
392
+ random_seed = random_seed ,
393
+ )
394
+ )
395
+
396
+ def fit (self , X , Z , y , t , coords , priors , ppc_sampler = None ):
397
+ """Draw samples from posterior distribution and potentially
398
+ from the prior and posterior predictive distributions. The
399
+ fit call can take values for the
400
+ ppc_sampler = ['jax', 'pymc', None]
401
+ We default to None, so the user can determine if they wish
402
+ to spend time sampling the posterior predictive distribution
403
+ independently.
404
+ """
405
+
406
+ # Ensure random_seed is used in sample_prior_predictive() and
407
+ # sample_posterior_predictive() if provided in sample_kwargs.
408
+ # Use JAX for ppc sampling of multivariate likelihood
409
+
410
+ self .build_model (X , Z , y , t , coords , priors )
411
+ with self :
412
+ self .idata = pm .sample (** self .sample_kwargs )
413
+ self .sample_predictive_distribution (ppc_sampler = ppc_sampler )
390
414
return self .idata
391
415
392
416
0 commit comments