Skip to content

Add NestedToMCMCAdapter to enable compatibility with MCMC (#2391) #2400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions arviz/data/io_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,101 @@
_log = logging.getLogger(__name__)


class NestedToMCMCAdapter:
"""
Adapter to convert a NestedSampler object into an MCMC-compatible interface.

This class reshapes posterior samples from a NestedSampler into a chain-and-draw
structure expected by MCMC workflows, providing compatibility with downstream
tools like ArviZ for posterior analysis.

Parameters
----------
nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
The NestedSampler object containing posterior samples.
rng_key : jax.random.PRNGKey
The random key used for sampling.
num_samples : int
The total number of posterior samples to draw.
num_chains : int, optional
The number of artificial chains to create for MCMC compatibility (default is 1).
*args : tuple
Additional positional arguments required by the model (e.g., data, labels).
**kwargs : dict
Additional keyword arguments required by the model.

Attributes
----------
samples : dict
Reshaped posterior samples organized by variable name.
thinning : int
Dummy thinning attribute for compatibility with MCMC.
sampler : NestedToMCMCAdapter
Mimics the sampler attribute of an MCMC object.
model : callable
The probabilistic model used in the NestedSampler.
_args : tuple
Positional arguments passed to the model.
_kwargs : dict
Keyword arguments passed to the model.

Methods
-------
get_samples(group_by_chain=True)
Returns posterior samples reshaped by chain or flattened if `group_by_chain` is False.
get_extra_fields(group_by_chain=True)
Provides dummy sampling statistics like accept probabilities, step sizes, and num_steps.
"""

def __init__(self, nested_sampler, rng_key, num_samples, *args, num_chains=1, **kwargs):
self.nested_sampler = nested_sampler
self.rng_key = rng_key
self.num_samples = num_samples
self.num_chains = num_chains
self.samples = self._reshape_samples()
self.thinning = 1
self.sampler = self
self.model = nested_sampler.model
self._args = args
self._kwargs = kwargs

def _reshape_samples(self):
raw_samples = self.nested_sampler.get_samples(self.rng_key, self.num_samples)
samples_per_chain = self.num_samples // self.num_chains
return {
k: np.reshape(
v[: samples_per_chain * self.num_chains],
(self.num_chains, samples_per_chain, *v.shape[1:]),
)
for k, v in raw_samples.items()
}

def get_samples(self, group_by_chain=True):
if group_by_chain:
return self.samples
else:
# Flatten chains into a single dimension
return {k: v.reshape(-1, *v.shape[2:]) for k, v in self.samples.items()}

def get_extra_fields(self, group_by_chain=True):
# Generate dummy fields since NestedSampler does not produce these
n_chains = self.num_chains
n_samples = self.num_samples // self.num_chains

# Create dummy values for extra fields
extra_fields = {
"accept_prob": np.full((n_chains, n_samples), 1.0), # Assume all proposals are accepted
"step_size": np.full((n_chains, n_samples), 0.1), # Dummy step size
"num_steps": np.full((n_chains, n_samples), 10), # Dummy number of steps
}

if not group_by_chain:
# Flatten the chains into a single dimension
extra_fields = {k: v.reshape(-1, *v.shape[2:]) for k, v in extra_fields.items()}

return extra_fields


class NumPyroConverter:
"""Encapsulate NumPyro specific logic."""

Expand All @@ -37,6 +132,10 @@ def __init__(
dims=None,
pred_dims=None,
num_chains=1,
rng_key=None,
num_samples=1000,
data=None,
labels=None,
):
"""Convert NumPyro data into an InferenceData object.

Expand Down Expand Up @@ -68,6 +167,15 @@ def __init__(
import numpyro

self.posterior = posterior
self.rng_key = rng_key
self.num_samples = num_samples

if isinstance(posterior, numpyro.contrib.nested_sampling.NestedSampler):
posterior = NestedToMCMCAdapter(
posterior, rng_key, num_samples, num_chains=num_chains, data=data, labels=labels
)
self.posterior = posterior

self.prior = jax.device_get(prior)
self.posterior_predictive = jax.device_get(posterior_predictive)
self.predictions = predictions
Expand Down Expand Up @@ -340,6 +448,10 @@ def from_numpyro(
dims=None,
pred_dims=None,
num_chains=1,
rng_key=None,
num_samples=1000,
data=None,
labels=None,
):
"""Convert NumPyro data into an InferenceData object.

Expand Down Expand Up @@ -383,4 +495,8 @@ def from_numpyro(
dims=dims,
pred_dims=pred_dims,
num_chains=num_chains,
rng_key=rng_key,
num_samples=num_samples,
data=data,
labels=labels,
).to_inference_data()
Loading