Skip to content
This repository was archived by the owner on Mar 19, 2021. It is now read-only.

Commit 73f4a11

Browse files
committed
ENH: add dtype variable to visualize histogram in traceplot
1 parent 58c43e8 commit 73f4a11

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

pystan/plots.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
logger = logging.getLogger('pystan')
44

55

6-
def traceplot(fit, pars, **kwargs):
6+
def traceplot(fit, pars, dtypes, **kwargs):
77
"""
8-
Use pymc's traceplot to display parameters.
9-
8+
Use pymc's traceplot to display parameters.
9+
1010
Additional arguments are passed to pymc.plots.traceplot.
1111
"""
1212
# FIXME: eventually put this in the StanFit object
@@ -16,4 +16,4 @@ def traceplot(fit, pars, **kwargs):
1616
except ImportError:
1717
logger.critical("matplotlib required for plotting.")
1818
raise
19-
return plots.traceplot(fit.extract(), pars, **kwargs)
19+
return plots.traceplot(fit.extract(dtypes=dtypes), pars, **kwargs)

pystan/stanfit4model.pyx

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ cdef class StanFit4Model:
484484

485485
# public methods
486486

487-
def plot(self, pars=None):
487+
def plot(self, pars=None, dtypes={}):
488488
"""Visualize samples from posterior distributions
489489
490490
Parameters
@@ -501,9 +501,9 @@ cdef class StanFit4Model:
501501
elif isinstance(pars, string_types):
502502
pars = [pars]
503503
pars = pystan.misc._remove_empty_pars(pars, self.sim['pars_oi'], self.sim['dims_oi'])
504-
return pystan.plots.traceplot(self, pars)
504+
return pystan.plots.traceplot(self, pars, dtypes)
505505

506-
def traceplot(self, pars=None):
506+
def traceplot(self, pars=None, dtypes={}):
507507
"""Visualize samples from posterior distributions
508508
509509
Parameters
@@ -512,9 +512,9 @@ cdef class StanFit4Model:
512512
parameter name(s); by default use all parameters of interest
513513
"""
514514
# FIXME: for now plot and traceplot do the same thing
515-
return self.plot(pars)
515+
return self.plot(pars, dtypes=dtypes)
516516

517-
def extract(self, pars=None, permuted=True, inc_warmup=False):
517+
def extract(self, pars=None, permuted=True, inc_warmup=False, dtypes={}):
518518
"""Extract samples in different forms for different parameters.
519519
520520
Parameters
@@ -567,7 +567,10 @@ cdef class StanFit4Model:
567567
for par in pars:
568568
sss = [pystan.misc._get_kept_samples(p, self.sim)
569569
for p in tidx[par]]
570-
s = {par: np.column_stack(sss)}
570+
ss = np.column_stack(sss)
571+
if par in dtypes.keys():
572+
ss = ss.astype(dtypes[par])
573+
s = {par: ss}
571574
extracted.update(s)
572575
par_idx = self.sim['pars_oi'].index(par)
573576
par_dim = self.sim['dims_oi'][par_idx]

0 commit comments

Comments
 (0)