Skip to content

Commit 05b557f

Browse files
authored
Implement CustomProgress that does not output empty divs when disabled (#7290)
* Replace Progress with CustomProgress * Add update method to CustomProgress * remove unused import * Replace Progress with CustomProgress * Add update method to CustomProgress * Remove some refreshes that slow things down * Remove some 'refresh' and make sure progress goes to 100%
1 parent 0216473 commit 05b557f

File tree

9 files changed

+130
-43
lines changed

9 files changed

+130
-43
lines changed

pymc/backends/arviz.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@
3232
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
3333
from pytensor.graph import ancestors
3434
from pytensor.tensor.sharedvar import SharedVariable
35-
from rich.progress import Console, Progress
35+
from rich.progress import Console
3636
from rich.theme import Theme
3737
from xarray import Dataset
3838

3939
import pymc
4040

4141
from pymc.model import Model, modelcontext
4242
from pymc.pytensorf import PointFunc, extract_obs_data
43-
from pymc.util import default_progress_theme, get_default_varnames
43+
from pymc.util import CustomProgress, default_progress_theme, get_default_varnames
4444

4545
if TYPE_CHECKING:
4646
from pymc.backends.base import MultiTrace
@@ -649,8 +649,10 @@ def apply_function_over_dataset(
649649
out_dict = _DefaultTrace(n_pts)
650650
indices = range(n_pts)
651651

652-
with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress:
653-
task = progress.add_task("Computing ...", total=n_pts, visible=progressbar)
652+
with CustomProgress(
653+
console=Console(theme=progressbar_theme), disable=not progressbar
654+
) as progress:
655+
task = progress.add_task("Computing ...", total=n_pts)
654656
for idx in indices:
655657
out = fn(posterior_pts[idx])
656658
fn.f.trust_input = True # If we arrive here the dtypes are valid

pymc/sampling/forward.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
)
4545
from pytensor.tensor.sharedvar import SharedVariable
4646
from rich.console import Console
47-
from rich.progress import Progress
47+
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4848
from rich.theme import Theme
4949

5050
import pymc as pm
@@ -55,6 +55,7 @@
5555
from pymc.model import Model, modelcontext
5656
from pymc.pytensorf import compile_pymc
5757
from pymc.util import (
58+
CustomProgress,
5859
RandomState,
5960
_get_seeds_per_chain,
6061
default_progress_theme,
@@ -828,11 +829,21 @@ def sample_posterior_predictive(
828829
# All model variables have a name, but mypy does not know this
829830
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
830831
ppc_trace_t = _DefaultTrace(samples)
832+
833+
progress = CustomProgress(
834+
"[progress.description]{task.description}",
835+
BarColumn(),
836+
"[progress.percentage]{task.percentage:>3.0f}%",
837+
TimeRemainingColumn(),
838+
TextColumn("/"),
839+
TimeElapsedColumn(),
840+
console=Console(theme=progressbar_theme),
841+
disable=not progressbar,
842+
)
843+
831844
try:
832-
with Progress(
833-
console=Console(theme=progressbar_theme), disable=not progressbar
834-
) as progress:
835-
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
845+
with progress:
846+
task = progress.add_task("Sampling ...", completed=0, total=samples)
836847
for idx in np.arange(samples):
837848
if nchain > 1:
838849
# the trace object will either be a MultiTrace (and have _straces)...
@@ -854,6 +865,7 @@ def sample_posterior_predictive(
854865
ppc_trace_t.insert(k.name, v, idx)
855866

856867
progress.advance(task)
868+
progress.update(task, refresh=True, completed=samples)
857869

858870
except KeyboardInterrupt:
859871
pass

pymc/sampling/mcmc.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from arviz.data.base import make_attrs
3737
from pytensor.graph.basic import Variable
3838
from rich.console import Console
39-
from rich.progress import Progress
39+
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4040
from rich.theme import Theme
4141
from threadpoolctl import threadpool_limits
4242
from typing_extensions import Protocol
@@ -65,6 +65,7 @@
6565
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
6666
from pymc.step_methods.hmc import quadpotential
6767
from pymc.util import (
68+
CustomProgress,
6869
RandomSeed,
6970
RandomState,
7071
_get_seeds_per_chain,
@@ -1075,14 +1076,28 @@ def _sample(
10751076
)
10761077
_pbar_data = {"chain": chain, "divergences": 0}
10771078
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1078-
with Progress(console=Console(theme=progressbar_theme)) as progress:
1079+
1080+
progress = CustomProgress(
1081+
"[progress.description]{task.description}",
1082+
BarColumn(),
1083+
"[progress.percentage]{task.percentage:>3.0f}%",
1084+
TimeRemainingColumn(),
1085+
TextColumn("/"),
1086+
TimeElapsedColumn(),
1087+
console=Console(theme=progressbar_theme),
1088+
disable=not progressbar,
1089+
)
1090+
1091+
with progress:
10791092
try:
1080-
task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar)
1093+
task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws)
10811094
for it, diverging in enumerate(sampling_gen):
10821095
if it >= skip_first and diverging:
10831096
_pbar_data["divergences"] += 1
1084-
progress.update(task, refresh=True, advance=1)
1085-
progress.update(task, refresh=True, advance=1, completed=True)
1097+
progress.update(task, description=_desc.format(**_pbar_data), completed=it)
1098+
progress.update(
1099+
task, description=_desc.format(**_pbar_data), completed=draws, refresh=True
1100+
)
10861101
except KeyboardInterrupt:
10871102
pass
10881103

pymc/sampling/parallel.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
import numpy as np
2828

2929
from rich.console import Console
30-
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
30+
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
3131
from rich.theme import Theme
3232
from threadpoolctl import threadpool_limits
3333

3434
from pymc.blocking import DictToArrayBijection
3535
from pymc.exceptions import SamplingError
36-
from pymc.util import RandomSeed, default_progress_theme
36+
from pymc.util import CustomProgress, RandomSeed, default_progress_theme
3737

3838
logger = logging.getLogger(__name__)
3939

@@ -431,7 +431,7 @@ def __init__(
431431

432432
self._in_context = False
433433

434-
self._progress = Progress(
434+
self._progress = CustomProgress(
435435
"[progress.description]{task.description}",
436436
BarColumn(),
437437
"[progress.percentage]{task.percentage:>3.0f}%",
@@ -465,7 +465,6 @@ def __iter__(self):
465465
self._desc.format(self),
466466
completed=self._completed_draws,
467467
total=self._total_draws,
468-
visible=self._show_progress,
469468
)
470469

471470
while self._active:
@@ -476,7 +475,6 @@ def __iter__(self):
476475
self._divergences += 1
477476
progress.update(
478477
task,
479-
refresh=True,
480478
completed=self._completed_draws,
481479
total=self._total_draws,
482480
description=self._desc.format(self),

pymc/sampling/population.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import cloudpickle
2525
import numpy as np
2626

27-
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
27+
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
2828

2929
from pymc.backends.base import BaseTrace
3030
from pymc.initial_point import PointType
@@ -37,7 +37,7 @@
3737
StatsType,
3838
)
3939
from pymc.step_methods.metropolis import DEMetropolis
40-
from pymc.util import RandomSeed
40+
from pymc.util import CustomProgress, RandomSeed
4141

4242
__all__ = ()
4343

@@ -100,11 +100,10 @@ def _sample_population(
100100
progressbar=progressbar,
101101
)
102102

103-
with Progress() as progress:
104-
task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar)
105-
103+
with CustomProgress(disable=not progressbar) as progress:
104+
task = progress.add_task("[red]Sampling...", total=draws)
106105
for _ in sampling:
107-
progress.update(task, advance=1, refresh=True)
106+
progress.update(task)
108107

109108
return
110109

@@ -175,20 +174,19 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
175174
)
176175
import multiprocessing
177176

178-
with Progress(
177+
with CustomProgress(
179178
"[progress.description]{task.description}",
180179
BarColumn(),
181180
"[progress.percentage]{task.percentage:>3.0f}%",
182181
TimeRemainingColumn(),
183182
TextColumn("/"),
184183
TimeElapsedColumn(),
184+
disable=not progressbar,
185185
) as self._progress:
186186
for c, stepper in enumerate(steppers):
187187
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
188188
# ):
189-
task = self._progress.add_task(
190-
description=f"Chain {c}", visible=progressbar
191-
)
189+
task = self._progress.add_task(description=f"Chain {c}")
192190
secondary_end, primary_end = multiprocessing.Pipe()
193191
stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
194192
process = multiprocessing.Process(

pymc/smc/sampling.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
from arviz import InferenceData
2828
from rich.progress import (
29-
Progress,
3029
SpinnerColumn,
3130
TextColumn,
3231
TimeElapsedColumn,
@@ -41,7 +40,7 @@
4140
from pymc.sampling.parallel import _cpu_count
4241
from pymc.smc.kernels import IMH
4342
from pymc.stats.convergence import log_warnings, run_convergence_checks
44-
from pymc.util import RandomState, _get_seeds_per_chain
43+
from pymc.util import CustomProgress, RandomState, _get_seeds_per_chain
4544

4645

4746
def sample_smc(
@@ -369,13 +368,14 @@ def _sample_smc_int(
369368

370369

371370
def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
372-
with Progress(
371+
with CustomProgress(
373372
TextColumn("{task.description}"),
374373
SpinnerColumn(),
375374
TimeRemainingColumn(),
376375
TextColumn("/"),
377376
TimeElapsedColumn(),
378377
TextColumn("{task.fields[status]}"),
378+
disable=not progressbar,
379379
) as progress:
380380
futures = [] # keep track of the jobs
381381
with multiprocessing.Manager() as manager:
@@ -390,9 +390,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
390390
with ProcessPoolExecutor(max_workers=cores) as executor:
391391
for c in range(chains): # iterate over the jobs we need to run
392392
# set visible false so we don't have a lot of bars all at once:
393-
task_id = progress.add_task(
394-
f"Chain {c}", status="Stage: 0 Beta: 0", visible=progressbar
395-
)
393+
task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0")
396394
futures.append(
397395
executor.submit(
398396
_sample_smc_int,

pymc/tuning/starting.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737
from pymc.blocking import DictToArrayBijection, RaveledVars
3838
from pymc.initial_point import make_initial_point_fn
3939
from pymc.model import modelcontext
40-
from pymc.util import default_progress_theme, get_default_varnames, get_value_vars_from_user_vars
40+
from pymc.util import (
41+
CustomProgress,
42+
default_progress_theme,
43+
get_default_varnames,
44+
get_value_vars_from_user_vars,
45+
)
4146
from pymc.vartypes import discrete_types, typefilter
4247

4348
__all__ = ["find_MAP"]
@@ -219,13 +224,13 @@ def __init__(
219224
self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}"
220225
self.previous_x = None
221226
self.progressbar = progressbar
222-
self.progress = Progress(
227+
self.progress = CustomProgress(
223228
*Progress.get_default_columns(),
224229
TextColumn("{task.fields[loss]}"),
225230
console=Console(theme=progressbar_theme),
226231
disable=not progressbar,
227232
)
228-
self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="")
233+
self.task = self.progress.add_task("MAP", total=maxeval, loss="")
229234

230235
def __call__(self, x):
231236
neg_value = np.float64(self.logp_func(pm.floatX(x)))

pymc/util.py

+58
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pytensor import Variable
2828
from pytensor.compile import SharedVariable
2929
from pytensor.graph.utils import ValidatingScratchpad
30+
from rich.progress import Progress
3031
from rich.theme import Theme
3132

3233
from pymc.exceptions import BlockModelAccessError
@@ -520,3 +521,60 @@ def makeiter(a):
520521
return a
521522
else:
522523
return [a]
524+
525+
526+
class CustomProgress(Progress):
527+
"""A child of Progress that allows to disable progress bars and its container
528+
529+
The implementation simply checks an `is_enabled` flag and generates the progress bar only if
530+
it's `True`.
531+
"""
532+
533+
def __init__(self, *args, **kwargs):
534+
self.is_enabled = kwargs.get("disable", None) is not True
535+
if self.is_enabled:
536+
super().__init__(*args, **kwargs)
537+
538+
def __enter__(self):
539+
if self.is_enabled:
540+
self.start()
541+
return self
542+
543+
def __exit__(self, exc_type, exc_val, exc_tb):
544+
if self.is_enabled:
545+
super().__exit__(exc_type, exc_val, exc_tb)
546+
547+
def add_task(self, *args, **kwargs):
548+
if self.is_enabled:
549+
return super().add_task(*args, **kwargs)
550+
return None
551+
552+
def advance(self, task_id, advance=1) -> None:
553+
if self.is_enabled:
554+
super().advance(task_id, advance)
555+
return None
556+
557+
def update(
558+
self,
559+
task_id,
560+
*,
561+
total=None,
562+
completed=None,
563+
advance=None,
564+
description=None,
565+
visible=None,
566+
refresh=False,
567+
**fields,
568+
):
569+
if self.is_enabled:
570+
super().update(
571+
task_id,
572+
total=total,
573+
completed=completed,
574+
advance=advance,
575+
description=description,
576+
visible=visible,
577+
refresh=refresh,
578+
**fields,
579+
)
580+
return None

0 commit comments

Comments
 (0)