Skip to content
This repository was archived by the owner on Mar 19, 2021. It is now read-only.

Commit 4098464

Browse files
authored
Merge pull request #359 from stan-dev/feature/enable_vars_in_misc._print_stanfit
Enable vars argument in misc._print_stanfit
2 parents cd7aba3 + 44bdf08 commit 4098464

File tree

4 files changed

+127
-36
lines changed

4 files changed

+127
-36
lines changed

pystan/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88

99
from pystan.api import stanc, stan
10-
from pystan.misc import read_rdump, stan_rdump
10+
from pystan.misc import read_rdump, stan_rdump, stansummary
1111
from pystan.model import StanModel
1212
from pystan.lookup import lookup
1313

pystan/misc.py

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -50,39 +50,74 @@
5050
logger = logging.getLogger('pystan')
5151

5252

53-
def _print_stanfit(fit, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2):
54-
if fit.mode == 1:
55-
return "Stan model '{}' is of mode 'test_grad';\n"\
56-
"sampling is not conducted.".format(fit.model_name)
57-
elif fit.mode == 2:
58-
return "Stan model '{}' does not contain samples.".format(fit.model_name)
59-
if pars is None:
60-
pars = fit.sim['pars_oi']
61-
fnames = fit.sim['fnames_oi']
62-
else:
63-
# FIXME: does this case ever occur?
64-
# need a way of getting fnames matching specified pars
65-
raise NotImplementedError
66-
67-
n_kept = [s - w for s, w in zip(fit.sim['n_save'], fit.sim['warmup2'])]
68-
header = "Inference for Stan model: {}.\n".format(fit.model_name)
69-
header += "{} chains, each with iter={}; warmup={}; thin={}; \n"
70-
header = header.format(fit.sim['chains'], fit.sim['iter'], fit.sim['warmup'],
71-
fit.sim['thin'], sum(n_kept))
72-
header += "post-warmup draws per chain={}, total post-warmup draws={}.\n\n"
73-
header = header.format(n_kept[0], sum(n_kept))
74-
footer = "\n\nSamples were drawn using {} at {}.\n"\
75-
"For each parameter, n_eff is a crude measure of effective sample size,\n"\
76-
"and Rhat is the potential scale reduction factor on split chains (at \n"\
77-
"convergence, Rhat=1)."
78-
sampler = fit.sim['samples'][0]['args']['sampler_t']
79-
date = fit.date.strftime('%c') # %c is locale's representation
80-
footer = footer.format(sampler, date)
81-
s = _summary(fit, pars, probs)
82-
body = _array_to_table(s['summary'], s['summary_rownames'],
83-
s['summary_colnames'], digits_summary)
84-
return header + body + footer
53+
def stansummary(fit, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2):
54+
"""
55+
Summary statistic table.
56+
57+
Parameters
58+
----------
59+
fit : StanFit4Model object
60+
pars : str or sequence of str, optional
61+
Parameter names. By default use all parameters
62+
probs : sequence of float, optional
63+
Quantiles. By default, (0.025, 0.25, 0.5, 0.75, 0.975)
64+
digits_summary : int, optional
65+
Number of significant digits. By default, 2
8566
67+
Returns
68+
-------
69+
summary : string
70+
Table includes mean, se_mean, sd, probs_0, ..., probs_n, n_eff and Rhat.
71+
72+
Examples
73+
--------
74+
>>> model_code = 'parameters {real y;} model {y ~ normal(0,1);}'
75+
>>> m = StanModel(model_code=model_code, model_name="example_model")
76+
>>> fit = m.sampling()
77+
>>> print(stansummary(fit))
78+
Inference for Stan model: example_model.
79+
4 chains, each with iter=2000; warmup=1000; thin=1;
80+
post-warmup draws per chain=1000, total post-warmup draws=4000.
81+
82+
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
83+
y 0.01 0.03 1.0 -2.01 -0.68 0.02 0.72 1.97 1330 1.0
84+
lp__ -0.5 0.02 0.68 -2.44 -0.66 -0.24 -0.05-5.5e-4 1555 1.0
85+
86+
Samples were drawn using NUTS at Thu Aug 17 00:52:25 2017.
87+
For each parameter, n_eff is a crude measure of effective sample size,
88+
and Rhat is the potential scale reduction factor on split chains (at
89+
convergence, Rhat=1).
90+
"""
91+
if fit.mode == 1:
92+
return "Stan model '{}' is of mode 'test_grad';\n"\
93+
"sampling is not conducted.".format(fit.model_name)
94+
elif fit.mode == 2:
95+
return "Stan model '{}' does not contain samples.".format(fit.model_name)
96+
97+
n_kept = [s - w for s, w in zip(fit.sim['n_save'], fit.sim['warmup2'])]
98+
header = "Inference for Stan model: {}.\n".format(fit.model_name)
99+
header += "{} chains, each with iter={}; warmup={}; thin={}; \n"
100+
header = header.format(fit.sim['chains'], fit.sim['iter'], fit.sim['warmup'],
101+
fit.sim['thin'], sum(n_kept))
102+
header += "post-warmup draws per chain={}, total post-warmup draws={}.\n\n"
103+
header = header.format(n_kept[0], sum(n_kept))
104+
footer = "\n\nSamples were drawn using {} at {}.\n"\
105+
"For each parameter, n_eff is a crude measure of effective sample size,\n"\
106+
"and Rhat is the potential scale reduction factor on split chains (at \n"\
107+
"convergence, Rhat=1)."
108+
sampler = fit.sim['samples'][0]['args']['sampler_t']
109+
date = fit.date.strftime('%c') # %c is locale's representation
110+
footer = footer.format(sampler, date)
111+
s = _summary(fit, pars, probs)
112+
body = _array_to_table(s['summary'], s['summary_rownames'],
113+
s['summary_colnames'], digits_summary)
114+
return header + body + footer
115+
116+
def _print_stanfit(fit, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2):
117+
# warning added in PyStan 2.17.0
118+
warnings.warn('Function `_print_stanfit` is deprecated and will be removed in a future version. '\
119+
'Use `stansummary` instead.', DeprecationWarning)
120+
return stansummary(fit, pars=pars, probs=probs, digits_summary=digits_summary)
86121

87122
def _array_to_table(arr, rownames, colnames, n_digits):
88123
"""Print an array with row and column names

pystan/stanfit4model.pyx

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,10 +614,10 @@ cdef class StanFit4Model:
614614

615615
def __unicode__(self):
616616
# for Python 2.x
617-
return pystan.misc._print_stanfit(self)
617+
return pystan.misc.stansummary(self)
618618

619619
def __str__(self):
620-
s = pystan.misc._print_stanfit(self)
620+
s = pystan.misc.stansummary(self)
621621
return s.encode('utf-8') if PY2 else s
622622

623623
def __repr__(self):
@@ -626,7 +626,46 @@ cdef class StanFit4Model:
626626
def __getitem__(self, key):
627627
extr = self.extract(pars=(key,))
628628
return extr[key]
629+
630+
def stansummary(self, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2):
631+
"""
632+
Summary statistic table.
629633
634+
Parameters
635+
----------
636+
fit : StanFit4Model object
637+
pars : str or sequence of str, optional
638+
Parameter names. By default use all parameters
639+
probs : sequence of float, optional
640+
Quantiles. By default, (0.025, 0.25, 0.5, 0.75, 0.975)
641+
digits_summary : int, optional
642+
Number of significant digits. By default, 2
643+
Returns
644+
-------
645+
summary : string
646+
Table includes mean, se_mean, sd, probs_0, ..., probs_n, n_eff and Rhat.
647+
648+
Examples
649+
--------
650+
>>> model_code = 'parameters {real y;} model {y ~ normal(0,1);}'
651+
>>> m = StanModel(model_code=model_code, model_name="example_model")
652+
>>> fit = m.sampling()
653+
>>> print(fit.stansummary())
654+
Inference for Stan model: example_model.
655+
4 chains, each with iter=2000; warmup=1000; thin=1;
656+
post-warmup draws per chain=1000, total post-warmup draws=4000.
657+
658+
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
659+
y 0.01 0.03 1.0 -2.01 -0.68 0.02 0.72 1.97 1330 1.0
660+
lp__ -0.5 0.02 0.68 -2.44 -0.66 -0.24 -0.05-5.5e-4 1555 1.0
661+
662+
Samples were drawn using NUTS at Thu Aug 17 00:52:25 2017.
663+
For each parameter, n_eff is a crude measure of effective sample size,
664+
and Rhat is the potential scale reduction factor on split chains (at
665+
convergence, Rhat=1).
666+
"""
667+
return pystan.misc.stansummary(fit=self, pars=pars, probs=probs, digits_summary=digits_summary)
668+
630669
def summary(self, pars=None, probs=None):
631670
return pystan.misc._summary(self, pars, probs)
632671

pystan/tests/test_misc_args.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class TestArgs(unittest.TestCase):
1010
@classmethod
1111
def setUpClass(cls):
12-
model_code = 'parameters {real y;} model {y ~ normal(0,1);}'
12+
model_code = 'parameters {real x;real y;real z;} model {x ~ normal(0,1);y ~ normal(0,1);z ~ normal(0,1);}'
1313
cls.model = pystan.StanModel(model_code=model_code)
1414

1515
def test_control(self):
@@ -23,3 +23,20 @@ def test_control(self):
2323
model.sampling(control=control_invalid)
2424
with assertRaisesRegex(ValueError, '`metric` must be one of'):
2525
model.sampling(control={'metric': 'lorem-ipsum'})
26+
27+
def test_print_summary(self):
28+
model = self.model
29+
fit = model.sampling(iter=100)
30+
31+
summary_full = pystan.misc.stansummary(fit)
32+
summary_one_par1 = pystan.misc.stansummary(fit, pars='z')
33+
summary_one_par2 = pystan.misc.stansummary(fit, pars=['z'])
34+
summary_pars = pystan.misc.stansummary(fit, pars=['x', 'y'])
35+
36+
self.assertNotEqual(summary_full, summary_one_par1)
37+
self.assertNotEqual(summary_full, summary_one_par2)
38+
self.assertNotEqual(summary_full, summary_pars)
39+
self.assertNotEqual(summary_one_par1, summary_pars)
40+
self.assertNotEqual(summary_one_par2, summary_pars)
41+
42+
self.assertEqual(summary_one_par1, summary_one_par2)

0 commit comments

Comments
 (0)