@@ -484,13 +484,18 @@ cdef class StanFit4Model:
484
484
485
485
# public methods
486
486
487
- def plot (self , pars = None ):
487
+ def plot (self , pars = None , dtypes = None ):
488
488
""" Visualize samples from posterior distributions
489
489
490
490
Parameters
491
491
---------
492
492
pars : {str, sequence of str}
493
493
parameter name(s); by default use all parameters of interest
494
+ dtypes : dict
495
+ datatype of parameter(s).
496
+ If nothing is passed, np.float will be used for all parameters.
497
+ If np.int is specified, the histogram will be visualized, not but
498
+ kde.
494
499
495
500
Note
496
501
----
@@ -501,20 +506,25 @@ cdef class StanFit4Model:
501
506
elif isinstance (pars, string_types):
502
507
pars = [pars]
503
508
pars = pystan.misc._remove_empty_pars(pars, self .sim[' pars_oi' ], self .sim[' dims_oi' ])
504
- return pystan.plots.traceplot(self , pars)
509
+ return pystan.plots.traceplot(self , pars, dtypes )
505
510
506
- def traceplot (self , pars = None ):
511
+ def traceplot (self , pars = None , dtypes = None ):
507
512
""" Visualize samples from posterior distributions
508
513
509
514
Parameters
510
515
---------
511
516
pars : {str, sequence of str}, optional
512
517
parameter name(s); by default use all parameters of interest
518
+ dtypes : dict
519
+ datatype of parameter(s).
520
+ If nothing is passed, np.float will be used for all parameters.
521
+ If np.int is specified, the histogram will be visualized, not but
522
+ kde.
513
523
"""
514
524
# FIXME: for now plot and traceplot do the same thing
515
- return self .plot(pars)
525
+ return self .plot(pars, dtypes = dtypes )
516
526
517
- def extract (self , pars = None , permuted = True , inc_warmup = False ):
527
+ def extract (self , pars = None , permuted = True , inc_warmup = False , dtypes = None ):
518
528
""" Extract samples in different forms for different parameters.
519
529
520
530
Parameters
@@ -528,6 +538,9 @@ cdef class StanFit4Model:
528
538
inc_warmup : bool
529
539
If True, warmup samples are kept; otherwise they are
530
540
discarded. If `permuted` is True, `inc_warmup` is ignored.
541
+ dtypes : dict
542
+ datatype of parameter(s).
543
+ If nothing is passed, np.float will be used for all parameters.
531
544
532
545
Returns
533
546
-------
@@ -545,12 +558,16 @@ cdef class StanFit4Model:
545
558
self ._verify_has_samples()
546
559
if inc_warmup is True and permuted is True :
547
560
logging.warn(" `inc_warmup` ignored when `permuted` is True." )
561
+ if dtypes is None and permuted is False :
562
+ logging.warn(" `dtypes` ignored when `permuted` is False." )
548
563
549
564
if pars is None :
550
565
pars = self .sim[' pars_oi' ]
551
566
elif isinstance (pars, string_types):
552
567
pars = [pars]
553
568
pars = pystan.misc._remove_empty_pars(pars, self .sim[' pars_oi' ], self .sim[' dims_oi' ])
569
+ if dtypes is None :
570
+ dtypes = {}
554
571
555
572
allpars = self .sim[' pars_oi' ] + self .sim[' fnames_oi' ]
556
573
pystan.misc._check_pars(allpars, pars)
@@ -567,7 +584,10 @@ cdef class StanFit4Model:
567
584
for par in pars:
568
585
sss = [pystan.misc._get_kept_samples(p, self .sim)
569
586
for p in tidx[par]]
570
- s = {par: np.column_stack(sss)}
587
+ ss = np.column_stack(sss)
588
+ if par in dtypes.keys():
589
+ ss = ss.astype(dtypes[par])
590
+ s = {par: ss}
571
591
extracted.update(s)
572
592
par_idx = self .sim[' pars_oi' ].index(par)
573
593
par_dim = self .sim[' dims_oi' ][par_idx]
0 commit comments