-
-
Notifications
You must be signed in to change notification settings - Fork 69
Open
Labels
enhancements (New feature or request), help wanted (Extra attention is needed)
Description
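Currently, sampling a `DiscreteMarkovChain` whose observed data contains a missing value (`np.nan`) raises a `TypeError`, as shown below. From the traceback, the NaN makes PyMC treat the variable as partially observed (`partial_observed_rv_logprob`). `discrete_mc_logp` then indexes the transition matrix `P` with slices of the joined, non-integer value tensor, which PyTensor rejects because indices must be integers or a boolean mask: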
# Code from notebooks/discrete_markov_chain.ipynb
# Minimal reproduction: a DiscreteMarkovChain observed on data that contains a NaN.
# `generate_chains` and `true_P` come from the notebook; `np`, `pm` and
# `DiscreteMarkovChain` are assumed imported there as well.

# Cast to float so that a missing value (NaN) can be written into the observations.
chains = generate_chains(true_P, 100, n_chains=50).astype(float)
chains[10, 10] = np.nan

with pm.Model() as model:
    x0 = pm.Categorical.dist(np.ones(3) / 3, size=(100,))
    P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
    # The NaN in `observed` makes this a partially observed variable. When
    # pm.sample() builds the model logp, discrete_mc_logp indexes P with the
    # joined (non-integer) value and fails with
    # "TypeError: index must be integers or a boolean mask".
    discrete_mc = DiscreteMarkovChain("MarkovChain", P=P, init_dist=x0, observed=chains)
    # Must run inside the model context: pm.sample() picks up `model` from it.
    idata = pm.sample()
Full Traceback
TypeError Traceback (most recent call last)
Cell In[57], line 5
3 P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
4 discrete_mc = DiscreteMarkovChain("MarkovChain", P=P, init_dist=x0, observed=chains)
----> 5 idata = pm.sample()
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/sampling/mcmc.py:782, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
779 msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
780 _log.warning(msg)
--> 782 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
783 exclusive_nuts = (
784 # User provided an instantiated NUTS step, and nothing else is needed
785 (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
(...)
792 )
793 )
795 if nuts_sampler != "pymc":
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/sampling/mcmc.py:245, in assign_step_methods(model, step, methods)
243 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
244 selected_steps: dict[type[BlockedStep], list] = {}
--> 245 model_logp = model.logp()
247 for var in model.value_vars:
248 if var not in assigned_vars:
249 # determine if a gradient can be computed
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/model/core.py:691, in Model.logp(self, vars, jacobian, sum)
689 rv_logps: list[TensorVariable] = []
690 if rvs:
--> 691 rv_logps = transformed_conditional_logp(
692 rvs=rvs,
693 rvs_to_values=self.rvs_to_values,
694 rvs_to_transforms=self.rvs_to_transforms,
695 jacobian=jacobian,
696 )
697 assert isinstance(rv_logps, list)
699 # Replace random variables by their value variables in potential terms
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/logprob/basic.py:570, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
567 transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore[arg-type]
569 kwargs.setdefault("warn_rvs", False)
--> 570 temp_logp_terms = conditional_logp(
571 rvs_to_values,
572 extra_rewrites=transform_rewrite,
573 use_jacobian=jacobian,
574 **kwargs,
575 )
577 # The function returns the logp for every single value term we provided to it.
578 # This includes the extra values we plugged in above, so we filter those we
579 # actually wanted in the same order they were given in.
580 logp_terms = {}
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/logprob/basic.py:500, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
497 node_values = remapped_vars[: len(node_values)]
498 node_inputs = remapped_vars[len(node_values) :]
--> 500 node_logprobs = _logprob(
501 node.op,
502 node_values,
503 *node_inputs,
504 **kwargs,
505 )
507 if not isinstance(node_logprobs, list | tuple):
508 node_logprobs = [node_logprobs]
File ~/mambaforge/envs/pymc-extras/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/distributions/distribution.py:825, in partial_observed_rv_logprob(op, values, dist, mask, **kwargs)
823 joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
824 joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
--> 825 joined_logp = logp(dist, joined_value)
827 # If we have a univariate RV we can split apart the logp terms
828 if op.ndim_supp == 0:
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/logprob/basic.py:189, in logp(rv, value, warn_rvs, **kwargs)
187 value = pt.as_tensor_variable(value, dtype=rv.dtype)
188 try:
--> 189 return _logprob_helper(rv, value, **kwargs)
190 except NotImplementedError:
191 fgraph = construct_ir_fgraph({rv: value})
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/logprob/abstract.py:70, in _logprob_helper(rv, *values, **kwargs)
68 def _logprob_helper(rv, *values, **kwargs):
69 """Help call `_logprob` dispatcher."""
---> 70 logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)
72 name = rv.name
73 if (not name) and (len(values) == 1):
File ~/mambaforge/envs/pymc-extras/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/Documents/Python/pymc_extras/pymc_extras/distributions/timeseries.py:263, in discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs)
260 indexes = [value[..., i : -(n_lags - i) if n_lags != i else None] for i in range(n_lags + 1)]
262 mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
--> 263 mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1)
265 # We cannot leave any RV in the logp graph, even if just for an assert
266 [init_dist_leading_dim] = constant_fold(
267 [pt.atleast_1d(init_dist).shape[0]], raise_not_constant=False
268 )
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pytensor/tensor/variable.py:557, in _tensor_py_operators.__getitem__(self, args)
554 advanced = True
556 if advanced:
--> 557 return pt.subtensor.advanced_subtensor(self, *args)
558 else:
559 if np.newaxis in args or NoneConst in args:
560 # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
561 # broadcastable dimension at this location". Since PyTensor adds
(...)
564 # then uses recursion to apply any other indices and add any
565 # remaining new axes.
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
249 def __call__(
250 self, *inputs: Any, name=None, return_list=False, **kwargs
251 ) -> Variable | list[Variable]:
252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
253
254 This method is just a wrapper around :meth:`Op.make_node`.
(...)
291
292 """
--> 293 node = self.make_node(*inputs, **kwargs)
294 if name is not None:
295 if len(node.outputs) == 1:
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pytensor/tensor/subtensor.py:2817, in AdvancedSubtensor.make_node(self, x, *index)
2815 def make_node(self, x, *index):
2816 x = as_tensor_variable(x)
-> 2817 index = tuple(map(as_index_variable, index))
2819 # We create a fake symbolic shape tuple and identify the broadcast
2820 # dimensions from the shape result of this entire subtensor operation.
2821 with config.change_flags(compute_test_value="off"):
File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pytensor/tensor/subtensor.py:2774, in as_index_variable(idx)
2772 idx = as_tensor_variable(idx)
2773 if idx.type.dtype not in discrete_dtypes:
-> 2774 raise TypeError("index must be integers or a boolean mask")
2775 if idx.type.dtype == "bool" and idx.type.ndim == 0:
2776 raise NotImplementedError(
2777 "Boolean scalar indexing not implemented. "
2778 "Open an issue in https://github.com/pymc-devs/pytensor/issues if you need this behavior."
2779 )
TypeError: index must be integers or a boolean mask
Metadata
Assignees
Labels
enhancements (New feature or request), help wanted (Extra attention is needed)