Skip to content

Commit 3f90553

Browse files
committed
Add NestedToMCMCAdapter to enable compatibility with ArviZ and MCMC workflows (#2391)
1 parent 529d795 commit 3f90553

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed

arviz/data/io_numpyro.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,101 @@
1313
_log = logging.getLogger(__name__)
1414

1515

16+
class NestedToMCMCAdapter:
17+
"""
18+
Adapter to convert a NestedSampler object into an MCMC-compatible interface.
19+
20+
This class reshapes posterior samples from a NestedSampler into a chain-and-draw
21+
structure expected by MCMC workflows, providing compatibility with downstream
22+
tools like ArviZ for posterior analysis.
23+
24+
Parameters
25+
----------
26+
nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
27+
The NestedSampler object containing posterior samples.
28+
rng_key : jax.random.PRNGKey
29+
The random key used for sampling.
30+
num_samples : int
31+
The total number of posterior samples to draw.
32+
num_chains : int, optional
33+
The number of artificial chains to create for MCMC compatibility (default is 1).
34+
*args : tuple
35+
Additional positional arguments required by the model (e.g., data, labels).
36+
**kwargs : dict
37+
Additional keyword arguments required by the model.
38+
39+
Attributes
40+
----------
41+
samples : dict
42+
Reshaped posterior samples organized by variable name.
43+
thinning : int
44+
Dummy thinning attribute for compatibility with MCMC.
45+
sampler : NestedToMCMCAdapter
46+
Mimics the sampler attribute of an MCMC object.
47+
model : callable
48+
The probabilistic model used in the NestedSampler.
49+
_args : tuple
50+
Positional arguments passed to the model.
51+
_kwargs : dict
52+
Keyword arguments passed to the model.
53+
54+
Methods
55+
-------
56+
get_samples(group_by_chain=True)
57+
Returns posterior samples reshaped by chain or flattened if `group_by_chain` is False.
58+
get_extra_fields(group_by_chain=True)
59+
Provides dummy sampling statistics like accept probabilities, step sizes, and num_steps.
60+
"""
61+
62+
def __init__(self, nested_sampler, rng_key, num_samples, *args, num_chains=1, **kwargs):
63+
self.nested_sampler = nested_sampler
64+
self.rng_key = rng_key
65+
self.num_samples = num_samples
66+
self.num_chains = num_chains
67+
self.samples = self._reshape_samples()
68+
self.thinning = 1
69+
self.sampler = self
70+
self.model = nested_sampler.model
71+
self._args = args
72+
self._kwargs = kwargs
73+
74+
def _reshape_samples(self):
75+
raw_samples = self.nested_sampler.get_samples(self.rng_key, self.num_samples)
76+
samples_per_chain = self.num_samples // self.num_chains
77+
return {
78+
k: np.reshape(
79+
v[: samples_per_chain * self.num_chains],
80+
(self.num_chains, samples_per_chain, *v.shape[1:]),
81+
)
82+
for k, v in raw_samples.items()
83+
}
84+
85+
def get_samples(self, group_by_chain=True):
86+
if group_by_chain:
87+
return self.samples
88+
else:
89+
# Flatten chains into a single dimension
90+
return {k: v.reshape(-1, *v.shape[2:]) for k, v in self.samples.items()}
91+
92+
def get_extra_fields(self, group_by_chain=True):
93+
# Generate dummy fields since NestedSampler does not produce these
94+
n_chains = self.num_chains
95+
n_samples = self.num_samples // self.num_chains
96+
97+
# Create dummy values for extra fields
98+
extra_fields = {
99+
"accept_prob": np.full((n_chains, n_samples), 1.0), # Assume all proposals are accepted
100+
"step_size": np.full((n_chains, n_samples), 0.1), # Dummy step size
101+
"num_steps": np.full((n_chains, n_samples), 10), # Dummy number of steps
102+
}
103+
104+
if not group_by_chain:
105+
# Flatten the chains into a single dimension
106+
extra_fields = {k: v.reshape(-1, *v.shape[2:]) for k, v in extra_fields.items()}
107+
108+
return extra_fields
109+
110+
16111
class NumPyroConverter:
17112
"""Encapsulate NumPyro specific logic."""
18113

@@ -37,6 +132,10 @@ def __init__(
37132
dims=None,
38133
pred_dims=None,
39134
num_chains=1,
135+
rng_key=None,
136+
num_samples=1000,
137+
data=None,
138+
labels=None,
40139
):
41140
"""Convert NumPyro data into an InferenceData object.
42141
@@ -68,6 +167,15 @@ def __init__(
68167
import numpyro
69168

70169
self.posterior = posterior
170+
self.rng_key = rng_key
171+
self.num_samples = num_samples
172+
173+
if isinstance(posterior, numpyro.contrib.nested_sampling.NestedSampler):
174+
posterior = NestedToMCMCAdapter(
175+
posterior, rng_key, num_samples, num_chains=num_chains, data=data, labels=labels
176+
)
177+
self.posterior = posterior
178+
71179
self.prior = jax.device_get(prior)
72180
self.posterior_predictive = jax.device_get(posterior_predictive)
73181
self.predictions = predictions
@@ -340,6 +448,10 @@ def from_numpyro(
340448
dims=None,
341449
pred_dims=None,
342450
num_chains=1,
451+
rng_key=None,
452+
num_samples=1000,
453+
data=None,
454+
labels=None,
343455
):
344456
"""Convert NumPyro data into an InferenceData object.
345457
@@ -383,4 +495,8 @@ def from_numpyro(
383495
dims=dims,
384496
pred_dims=pred_dims,
385497
num_chains=num_chains,
498+
rng_key=rng_key,
499+
num_samples=num_samples,
500+
data=data,
501+
labels=labels,
386502
).to_inference_data()

0 commit comments

Comments
 (0)