Skip to content

Commit 881030a

Browse files
Suggest var_names when using deprecated API for partial traces (#7289)
* Update docs when list is passed to trace in sample for partial trace * Change DeprecationWarning to ValueError Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> * Update test_partial_trace_unsupported --------- Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
1 parent 606d4ff commit 881030a

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

pymc/sampling/mcmc.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,7 @@ def sample(
630630
else:
631631
kwargs["nuts"] = {"target_accept": kwargs.pop("target_accept")}
632632
if isinstance(trace, list):
633-
raise DeprecationWarning(
634-
"We have removed support for partial traces because it simplified things."
635-
" Please open an issue if & why this is a problem for you."
636-
)
633+
raise ValueError("Please use `var_names` keyword argument for partial traces.")
637634

638635
model = modelcontext(model)
639636
if not model.free_RVs:

tests/sampling/test_mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,11 +507,11 @@ def test_empty_model():
507507
error.match("any free variables")
508508

509509

510-
def test_partial_trace_unsupported():
510+
def test_partial_trace_with_trace_unsupported():
511511
with pm.Model() as model:
512512
a = pm.Normal("a", mu=0, sigma=1)
513513
b = pm.Normal("b", mu=0, sigma=1)
514-
with pytest.raises(DeprecationWarning, match="removed support"):
514+
with pytest.raises(ValueError, match="var_names"):
515515
pm.sample(trace=[a])
516516

517517

0 commit comments

Comments
 (0)