Skip to content

Support missing value interpolation when observed is DiscreteMarkovChain #547

@jessegrabowski

Description

@jessegrabowski

Currently, this raises an error:

# Code from notebooks/discrete_markov_chain.ipynb

chains = generate_chains(true_P, 100, n_chains=50).astype(float)
chains[10, 10] = np.nan

with pm.Model() as model:
    x0 = pm.Categorical.dist(np.ones(3) / 3, size=(100,))
    P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
    discrete_mc = DiscreteMarkovChain("MarkovChain", P=P, init_dist=x0, observed=chains)
    idata = pm.sample()
Full Traceback
TypeError                                 Traceback (most recent call last)
Cell In[57], line 5
      3 P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
      4 discrete_mc = DiscreteMarkovChain("MarkovChain", P=P, init_dist=x0, observed=chains)
----> 5 idata = pm.sample()

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/sampling/mcmc.py:782, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    779     msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
    780     _log.warning(msg)
--> 782 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
    783 exclusive_nuts = (
    784     # User provided an instantiated NUTS step, and nothing else is needed
    785     (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
   (...)
    792     )
    793 )
    795 if nuts_sampler != "pymc":

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/sampling/mcmc.py:245, in assign_step_methods(model, step, methods)
    243 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    244 selected_steps: dict[type[BlockedStep], list] = {}
--> 245 model_logp = model.logp()
    247 for var in model.value_vars:
    248     if var not in assigned_vars:
    249         # determine if a gradient can be computed

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/model/core.py:691, in Model.logp(self, vars, jacobian, sum)
    689 rv_logps: list[TensorVariable] = []
    690 if rvs:
--> 691     rv_logps = transformed_conditional_logp(
    692         rvs=rvs,
    693         rvs_to_values=self.rvs_to_values,
    694         rvs_to_transforms=self.rvs_to_transforms,
    695         jacobian=jacobian,
    696     )
    697     assert isinstance(rv_logps, list)
    699 # Replace random variables by their value variables in potential terms

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/logprob/basic.py:570, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    567     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore[arg-type]
    569 kwargs.setdefault("warn_rvs", False)
--> 570 temp_logp_terms = conditional_logp(
    571     rvs_to_values,
    572     extra_rewrites=transform_rewrite,
    573     use_jacobian=jacobian,
    574     **kwargs,
    575 )
    577 # The function returns the logp for every single value term we provided to it.
    578 # This includes the extra values we plugged in above, so we filter those we
    579 # actually wanted in the same order they were given in.
    580 logp_terms = {}

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/logprob/basic.py:500, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    497 node_values = remapped_vars[: len(node_values)]
    498 node_inputs = remapped_vars[len(node_values) :]
--> 500 node_logprobs = _logprob(
    501     node.op,
    502     node_values,
    503     *node_inputs,
    504     **kwargs,
    505 )
    507 if not isinstance(node_logprobs, list | tuple):
    508     node_logprobs = [node_logprobs]

File ~/mambaforge/envs/pymc-extras/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/distributions/distribution.py:825, in partial_observed_rv_logprob(op, values, dist, mask, **kwargs)
    823 joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
    824 joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
--> 825 joined_logp = logp(dist, joined_value)
    827 # If we have a univariate RV we can split apart the logp terms
    828 if op.ndim_supp == 0:

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/logprob/basic.py:189, in logp(rv, value, warn_rvs, **kwargs)
    187 value = pt.as_tensor_variable(value, dtype=rv.dtype)
    188 try:
--> 189     return _logprob_helper(rv, value, **kwargs)
    190 except NotImplementedError:
    191     fgraph = construct_ir_fgraph({rv: value})

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pymc/logprob/abstract.py:70, in _logprob_helper(rv, *values, **kwargs)
     68 def _logprob_helper(rv, *values, **kwargs):
     69     """Help call `_logprob` dispatcher."""
---> 70     logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)
     72     name = rv.name
     73     if (not name) and (len(values) == 1):

File ~/mambaforge/envs/pymc-extras/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/Documents/Python/pymc_extras/pymc_extras/distributions/timeseries.py:263, in discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs)
    260 indexes = [value[..., i : -(n_lags - i) if n_lags != i else None] for i in range(n_lags + 1)]
    262 mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
--> 263 mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1)
    265 # We cannot leave any RV in the logp graph, even if just for an assert
    266 [init_dist_leading_dim] = constant_fold(
    267     [pt.atleast_1d(init_dist).shape[0]], raise_not_constant=False
    268 )

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pytensor/tensor/variable.py:557, in _tensor_py_operators.__getitem__(self, args)
    554                 advanced = True
    556 if advanced:
--> 557     return pt.subtensor.advanced_subtensor(self, *args)
    558 else:
    559     if np.newaxis in args or NoneConst in args:
    560         # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
    561         # broadcastable dimension at this location".  Since PyTensor adds
   (...)
    564         # then uses recursion to apply any other indices and add any
    565         # remaining new axes.

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
    249 def __call__(
    250     self, *inputs: Any, name=None, return_list=False, **kwargs
    251 ) -> Variable | list[Variable]:
    252     r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254     This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    291 
    292     """
--> 293     node = self.make_node(*inputs, **kwargs)
    294     if name is not None:
    295         if len(node.outputs) == 1:

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pytensor/tensor/subtensor.py:2817, in AdvancedSubtensor.make_node(self, x, *index)
   2815 def make_node(self, x, *index):
   2816     x = as_tensor_variable(x)
-> 2817     index = tuple(map(as_index_variable, index))
   2819     # We create a fake symbolic shape tuple and identify the broadcast
   2820     # dimensions from the shape result of this entire subtensor operation.
   2821     with config.change_flags(compute_test_value="off"):

File ~/mambaforge/envs/pymc-extras/lib/python3.12/site-packages/pytensor/tensor/subtensor.py:2774, in as_index_variable(idx)
   2772 idx = as_tensor_variable(idx)
   2773 if idx.type.dtype not in discrete_dtypes:
-> 2774     raise TypeError("index must be integers or a boolean mask")
   2775 if idx.type.dtype == "bool" and idx.type.ndim == 0:
   2776     raise NotImplementedError(
   2777         "Boolean scalar indexing not implemented. "
   2778         "Open an issue in https://github.com/pymc-devs/pytensor/issues if you need this behavior."
   2779     )

TypeError: index must be integers or a boolean mask

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement (New feature or request), help wanted (Extra attention is needed)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions