Skip to content
This repository was archived by the owner on Mar 19, 2021. It is now read-only.

Commit 9f09140

Browse files
authored
Merge pull request #340 from TaskeHAMANO/feature/add_dtype
ENH: add dtype variable to visualize histogram in traceplot
2 parents 3a9181a + e1350a5 commit 9f09140

File tree

3 files changed

+40
-10
lines changed

3 files changed

+40
-10
lines changed

pystan/plots.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
logger = logging.getLogger('pystan')
44

55

6-
def traceplot(fit, pars, **kwargs):
6+
def traceplot(fit, pars, dtypes, **kwargs):
77
"""
8-
Use pymc's traceplot to display parameters.
9-
8+
Use pymc's traceplot to display parameters.
9+
1010
Additional arguments are passed to pymc.plots.traceplot.
1111
"""
1212
# FIXME: eventually put this in the StanFit object
@@ -16,4 +16,4 @@ def traceplot(fit, pars, **kwargs):
1616
except ImportError:
1717
logger.critical("matplotlib required for plotting.")
1818
raise
19-
return plots.traceplot(fit.extract(), pars, **kwargs)
19+
return plots.traceplot(fit.extract(dtypes=dtypes), pars, **kwargs)

pystan/stanfit4model.pyx

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,13 +484,18 @@ cdef class StanFit4Model:
484484

485485
# public methods
486486

487-
def plot(self, pars=None):
487+
def plot(self, pars=None, dtypes=None):
488488
"""Visualize samples from posterior distributions
489489
490490
Parameters
491491
---------
492492
pars : {str, sequence of str}
493493
parameter name(s); by default use all parameters of interest
494+
dtypes : dict
495+
datatype of parameter(s).
496+
If nothing is passed, np.float will be used for all parameters.
497+
If np.int is specified, the histogram will be visualized, not but
498+
kde.
494499
495500
Note
496501
----
@@ -501,20 +506,25 @@ cdef class StanFit4Model:
501506
elif isinstance(pars, string_types):
502507
pars = [pars]
503508
pars = pystan.misc._remove_empty_pars(pars, self.sim['pars_oi'], self.sim['dims_oi'])
504-
return pystan.plots.traceplot(self, pars)
509+
return pystan.plots.traceplot(self, pars, dtypes)
505510

506-
def traceplot(self, pars=None):
511+
def traceplot(self, pars=None, dtypes=None):
507512
"""Visualize samples from posterior distributions
508513
509514
Parameters
510515
---------
511516
pars : {str, sequence of str}, optional
512517
parameter name(s); by default use all parameters of interest
518+
dtypes : dict
519+
datatype of parameter(s).
520+
If nothing is passed, np.float will be used for all parameters.
521+
If np.int is specified, the histogram will be visualized, not but
522+
kde.
513523
"""
514524
# FIXME: for now plot and traceplot do the same thing
515-
return self.plot(pars)
525+
return self.plot(pars, dtypes=dtypes)
516526

517-
def extract(self, pars=None, permuted=True, inc_warmup=False):
527+
def extract(self, pars=None, permuted=True, inc_warmup=False, dtypes=None):
518528
"""Extract samples in different forms for different parameters.
519529
520530
Parameters
@@ -528,6 +538,9 @@ cdef class StanFit4Model:
528538
inc_warmup : bool
529539
If True, warmup samples are kept; otherwise they are
530540
discarded. If `permuted` is True, `inc_warmup` is ignored.
541+
dtypes : dict
542+
datatype of parameter(s).
543+
If nothing is passed, np.float will be used for all parameters.
531544
532545
Returns
533546
-------
@@ -545,12 +558,16 @@ cdef class StanFit4Model:
545558
self._verify_has_samples()
546559
if inc_warmup is True and permuted is True:
547560
logging.warn("`inc_warmup` ignored when `permuted` is True.")
561+
if dtypes is None and permuted is False:
562+
logging.warn("`dtypes` ignored when `permuted` is False.")
548563

549564
if pars is None:
550565
pars = self.sim['pars_oi']
551566
elif isinstance(pars, string_types):
552567
pars = [pars]
553568
pars = pystan.misc._remove_empty_pars(pars, self.sim['pars_oi'], self.sim['dims_oi'])
569+
if dtypes is None:
570+
dtypes = {}
554571

555572
allpars = self.sim['pars_oi'] + self.sim['fnames_oi']
556573
pystan.misc._check_pars(allpars, pars)
@@ -567,7 +584,10 @@ cdef class StanFit4Model:
567584
for par in pars:
568585
sss = [pystan.misc._get_kept_samples(p, self.sim)
569586
for p in tidx[par]]
570-
s = {par: np.column_stack(sss)}
587+
ss = np.column_stack(sss)
588+
if par in dtypes.keys():
589+
ss = ss.astype(dtypes[par])
590+
s = {par: ss}
571591
extracted.update(s)
572592
par_idx = self.sim['pars_oi'].index(par)
573593
par_dim = self.sim['dims_oi'][par_idx]

pystan/tests/test_extract.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,13 @@ def test_extract_thin(self):
8686
ss = fit.extract(inc_warmup=True, permuted=False)
8787
self.assertEqual(ss.shape, (1000, 4, 9))
8888
self.assertTrue((~np.isnan(ss)).all())
89+
90+
def test_extract_dtype(self):
91+
dtypes = {"alpha": np.int, "beta": np.int}
92+
ss = self.fit.extract(dtypes = dtypes)
93+
alpha = ss['alpha']
94+
beta = ss['beta']
95+
lp__ = ss['lp__']
96+
self.assertEqual(alpha.dtype, np.dtype(np.int))
97+
self.assertEqual(beta.dtype, np.dtype(np.int))
98+
self.assertEqual(lp__.dtype, np.dtype(np.float))

0 commit comments

Comments
 (0)