@@ -484,7 +484,7 @@ cdef class StanFit4Model:
484
484
485
485
# public methods
486
486
487
- def plot (self , pars = None ):
487
+ def plot (self , pars = None , dtypes = {} ):
488
488
""" Visualize samples from posterior distributions
489
489
490
490
Parameters
@@ -501,9 +501,9 @@ cdef class StanFit4Model:
501
501
elif isinstance (pars, string_types):
502
502
pars = [pars]
503
503
pars = pystan.misc._remove_empty_pars(pars, self .sim[' pars_oi' ], self .sim[' dims_oi' ])
504
- return pystan.plots.traceplot(self , pars)
504
+ return pystan.plots.traceplot(self , pars, dtypes )
505
505
506
- def traceplot (self , pars = None ):
506
+ def traceplot (self , pars = None , dtypes = {} ):
507
507
""" Visualize samples from posterior distributions
508
508
509
509
Parameters
@@ -512,9 +512,9 @@ cdef class StanFit4Model:
512
512
parameter name(s); by default use all parameters of interest
513
513
"""
514
514
# FIXME: for now plot and traceplot do the same thing
515
- return self .plot(pars)
515
+ return self .plot(pars, dtypes = dtypes )
516
516
517
- def extract (self , pars = None , permuted = True , inc_warmup = False ):
517
+ def extract (self , pars = None , permuted = True , inc_warmup = False , dtypes = {} ):
518
518
""" Extract samples in different forms for different parameters.
519
519
520
520
Parameters
@@ -567,7 +567,10 @@ cdef class StanFit4Model:
567
567
for par in pars:
568
568
sss = [pystan.misc._get_kept_samples(p, self .sim)
569
569
for p in tidx[par]]
570
- s = {par: np.column_stack(sss)}
570
+ ss = np.column_stack(sss)
571
+ if par in dtypes.keys():
572
+ ss = ss.astype(dtypes[par])
573
+ s = {par: ss}
571
574
extracted.update(s)
572
575
par_idx = self .sim[' pars_oi' ].index(par)
573
576
par_dim = self .sim[' dims_oi' ][par_idx]
0 commit comments