|
36 | 36 | from arviz.data.base import make_attrs
|
37 | 37 | from pytensor.graph.basic import Variable
|
38 | 38 | from rich.console import Console
|
39 |
| -from rich.progress import Progress |
| 39 | +from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn |
40 | 40 | from rich.theme import Theme
|
41 | 41 | from threadpoolctl import threadpool_limits
|
42 | 42 | from typing_extensions import Protocol
|
|
65 | 65 | from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
|
66 | 66 | from pymc.step_methods.hmc import quadpotential
|
67 | 67 | from pymc.util import (
|
| 68 | + CustomProgress, |
68 | 69 | RandomSeed,
|
69 | 70 | RandomState,
|
70 | 71 | _get_seeds_per_chain,
|
@@ -1075,14 +1076,28 @@ def _sample(
|
1075 | 1076 | )
|
1076 | 1077 | _pbar_data = {"chain": chain, "divergences": 0}
|
1077 | 1078 | _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
|
1078 |
| - with Progress(console=Console(theme=progressbar_theme)) as progress: |
| 1079 | + |
| 1080 | + progress = CustomProgress( |
| 1081 | + "[progress.description]{task.description}", |
| 1082 | + BarColumn(), |
| 1083 | + "[progress.percentage]{task.percentage:>3.0f}%", |
| 1084 | + TimeRemainingColumn(), |
| 1085 | + TextColumn("/"), |
| 1086 | + TimeElapsedColumn(), |
| 1087 | + console=Console(theme=progressbar_theme), |
| 1088 | + disable=not progressbar, |
| 1089 | + ) |
| 1090 | + |
| 1091 | + with progress: |
1079 | 1092 | try:
|
1080 |
| - task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar) |
| 1093 | + task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws) |
1081 | 1094 | for it, diverging in enumerate(sampling_gen):
|
1082 | 1095 | if it >= skip_first and diverging:
|
1083 | 1096 | _pbar_data["divergences"] += 1
|
1084 |
| - progress.update(task, refresh=True, advance=1) |
1085 |
| - progress.update(task, refresh=True, advance=1, completed=True) |
| 1097 | + progress.update(task, description=_desc.format(**_pbar_data), completed=it) |
| 1098 | + progress.update( |
| 1099 | + task, description=_desc.format(**_pbar_data), completed=draws, refresh=True |
| 1100 | + ) |
1086 | 1101 | except KeyboardInterrupt:
|
1087 | 1102 | pass
|
1088 | 1103 |
|
|
0 commit comments