|
50 | 50 | logger = logging.getLogger('pystan')
|
51 | 51 |
|
52 | 52 |
|
53 |
| -def _print_stanfit(fit, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2): |
54 |
| - if fit.mode == 1: |
55 |
| - return "Stan model '{}' is of mode 'test_grad';\n"\ |
56 |
| - "sampling is not conducted.".format(fit.model_name) |
57 |
| - elif fit.mode == 2: |
58 |
| - return "Stan model '{}' does not contain samples.".format(fit.model_name) |
59 |
| - if pars is None: |
60 |
| - pars = fit.sim['pars_oi'] |
61 |
| - fnames = fit.sim['fnames_oi'] |
62 |
| - else: |
63 |
| - # FIXME: does this case ever occur? |
64 |
| - # need a way of getting fnames matching specified pars |
65 |
| - raise NotImplementedError |
66 |
| - |
67 |
| - n_kept = [s - w for s, w in zip(fit.sim['n_save'], fit.sim['warmup2'])] |
68 |
| - header = "Inference for Stan model: {}.\n".format(fit.model_name) |
69 |
| - header += "{} chains, each with iter={}; warmup={}; thin={}; \n" |
70 |
| - header = header.format(fit.sim['chains'], fit.sim['iter'], fit.sim['warmup'], |
71 |
| - fit.sim['thin'], sum(n_kept)) |
72 |
| - header += "post-warmup draws per chain={}, total post-warmup draws={}.\n\n" |
73 |
| - header = header.format(n_kept[0], sum(n_kept)) |
74 |
| - footer = "\n\nSamples were drawn using {} at {}.\n"\ |
75 |
| - "For each parameter, n_eff is a crude measure of effective sample size,\n"\ |
76 |
| - "and Rhat is the potential scale reduction factor on split chains (at \n"\ |
77 |
| - "convergence, Rhat=1)." |
78 |
| - sampler = fit.sim['samples'][0]['args']['sampler_t'] |
79 |
| - date = fit.date.strftime('%c') # %c is locale's representation |
80 |
| - footer = footer.format(sampler, date) |
81 |
| - s = _summary(fit, pars, probs) |
82 |
| - body = _array_to_table(s['summary'], s['summary_rownames'], |
83 |
| - s['summary_colnames'], digits_summary) |
84 |
| - return header + body + footer |
| 53 | +def stansummary(fit, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2): |
| 54 | + """ |
| 55 | + Summary statistic table. |
| 56 | + |
| 57 | + Parameters |
| 58 | + ---------- |
| 59 | + fit : StanFit4Model object |
| 60 | + pars : str or sequence of str, optional |
| 61 | + Parameter names. By default use all parameters |
| 62 | + probs : sequence of float, optional |
| 63 | + Quantiles. By default, (0.025, 0.25, 0.5, 0.75, 0.975) |
| 64 | + digits_summary : int, optional |
| 65 | + Number of significant digits. By default, 2 |
85 | 66 |
|
| 67 | + Returns |
| 68 | + ------- |
| 69 | + summary : string |
| 70 | + Table includes mean, se_mean, sd, probs_0, ..., probs_n, n_eff and Rhat. |
| 71 | + |
| 72 | + Examples |
| 73 | + -------- |
| 74 | + >>> model_code = 'parameters {real y;} model {y ~ normal(0,1);}' |
| 75 | + >>> m = StanModel(model_code=model_code, model_name="example_model") |
| 76 | + >>> fit = m.sampling() |
| 77 | + >>> print(stansummary(fit)) |
| 78 | + Inference for Stan model: example_model. |
| 79 | + 4 chains, each with iter=2000; warmup=1000; thin=1; |
| 80 | + post-warmup draws per chain=1000, total post-warmup draws=4000. |
| 81 | + |
| 82 | + mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat |
| 83 | + y 0.01 0.03 1.0 -2.01 -0.68 0.02 0.72 1.97 1330 1.0 |
| 84 | + lp__ -0.5 0.02 0.68 -2.44 -0.66 -0.24 -0.05-5.5e-4 1555 1.0 |
| 85 | + |
| 86 | + Samples were drawn using NUTS at Thu Aug 17 00:52:25 2017. |
| 87 | + For each parameter, n_eff is a crude measure of effective sample size, |
| 88 | + and Rhat is the potential scale reduction factor on split chains (at |
| 89 | + convergence, Rhat=1). |
| 90 | + """ |
| 91 | + if fit.mode == 1: |
| 92 | + return "Stan model '{}' is of mode 'test_grad';\n"\ |
| 93 | + "sampling is not conducted.".format(fit.model_name) |
| 94 | + elif fit.mode == 2: |
| 95 | + return "Stan model '{}' does not contain samples.".format(fit.model_name) |
| 96 | + |
| 97 | + n_kept = [s - w for s, w in zip(fit.sim['n_save'], fit.sim['warmup2'])] |
| 98 | + header = "Inference for Stan model: {}.\n".format(fit.model_name) |
| 99 | + header += "{} chains, each with iter={}; warmup={}; thin={}; \n" |
| 100 | + header = header.format(fit.sim['chains'], fit.sim['iter'], fit.sim['warmup'], |
| 101 | + fit.sim['thin'], sum(n_kept)) |
| 102 | + header += "post-warmup draws per chain={}, total post-warmup draws={}.\n\n" |
| 103 | + header = header.format(n_kept[0], sum(n_kept)) |
| 104 | + footer = "\n\nSamples were drawn using {} at {}.\n"\ |
| 105 | + "For each parameter, n_eff is a crude measure of effective sample size,\n"\ |
| 106 | + "and Rhat is the potential scale reduction factor on split chains (at \n"\ |
| 107 | + "convergence, Rhat=1)." |
| 108 | + sampler = fit.sim['samples'][0]['args']['sampler_t'] |
| 109 | + date = fit.date.strftime('%c') # %c is locale's representation |
| 110 | + footer = footer.format(sampler, date) |
| 111 | + s = _summary(fit, pars, probs) |
| 112 | + body = _array_to_table(s['summary'], s['summary_rownames'], |
| 113 | + s['summary_colnames'], digits_summary) |
| 114 | + return header + body + footer |
| 115 | + |
| 116 | +def _print_stanfit(fit, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2): |
| 117 | + # warning added in PyStan 2.17.0 |
| 118 | + warnings.warn('Function `_print_stanfit` is deprecated and will be removed in a future version. '\ |
| 119 | + 'Use `stansummary` instead.', DeprecationWarning) |
| 120 | + return stansummary(fit, pars=pars, probs=probs, digits_summary=digits_summary) |
86 | 121 |
|
87 | 122 | def _array_to_table(arr, rownames, colnames, n_digits):
|
88 | 123 | """Print an array with row and column names
|
|
0 commit comments