13
13
_log = logging .getLogger (__name__ )
14
14
15
15
16
+ class NestedToMCMCAdapter :
17
+ """
18
+ Adapter to convert a NestedSampler object into an MCMC-compatible interface.
19
+
20
+ This class reshapes posterior samples from a NestedSampler into a chain-and-draw
21
+ structure expected by MCMC workflows, providing compatibility with downstream
22
+ tools like ArviZ for posterior analysis.
23
+
24
+ Parameters
25
+ ----------
26
+ nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
27
+ The NestedSampler object containing posterior samples.
28
+ rng_key : jax.random.PRNGKey
29
+ The random key used for sampling.
30
+ num_samples : int
31
+ The total number of posterior samples to draw.
32
+ num_chains : int, optional
33
+ The number of artificial chains to create for MCMC compatibility (default is 1).
34
+ *args : tuple
35
+ Additional positional arguments required by the model (e.g., data, labels).
36
+ **kwargs : dict
37
+ Additional keyword arguments required by the model.
38
+
39
+ Attributes
40
+ ----------
41
+ samples : dict
42
+ Reshaped posterior samples organized by variable name.
43
+ thinning : int
44
+ Dummy thinning attribute for compatibility with MCMC.
45
+ sampler : NestedToMCMCAdapter
46
+ Mimics the sampler attribute of an MCMC object.
47
+ model : callable
48
+ The probabilistic model used in the NestedSampler.
49
+ _args : tuple
50
+ Positional arguments passed to the model.
51
+ _kwargs : dict
52
+ Keyword arguments passed to the model.
53
+
54
+ Methods
55
+ -------
56
+ get_samples(group_by_chain=True)
57
+ Returns posterior samples reshaped by chain or flattened if `group_by_chain` is False.
58
+ get_extra_fields(group_by_chain=True)
59
+ Provides dummy sampling statistics like accept probabilities, step sizes, and num_steps.
60
+ """
61
+
62
+ def __init__ (self , nested_sampler , rng_key , num_samples , * args , num_chains = 1 , ** kwargs ):
63
+ self .nested_sampler = nested_sampler
64
+ self .rng_key = rng_key
65
+ self .num_samples = num_samples
66
+ self .num_chains = num_chains
67
+ self .samples = self ._reshape_samples ()
68
+ self .thinning = 1
69
+ self .sampler = self
70
+ self .model = nested_sampler .model
71
+ self ._args = args
72
+ self ._kwargs = kwargs
73
+
74
+ def _reshape_samples (self ):
75
+ raw_samples = self .nested_sampler .get_samples (self .rng_key , self .num_samples )
76
+ samples_per_chain = self .num_samples // self .num_chains
77
+ return {
78
+ k : np .reshape (
79
+ v [: samples_per_chain * self .num_chains ],
80
+ (self .num_chains , samples_per_chain , * v .shape [1 :]),
81
+ )
82
+ for k , v in raw_samples .items ()
83
+ }
84
+
85
+ def get_samples (self , group_by_chain = True ):
86
+ if group_by_chain :
87
+ return self .samples
88
+ else :
89
+ # Flatten chains into a single dimension
90
+ return {k : v .reshape (- 1 , * v .shape [2 :]) for k , v in self .samples .items ()}
91
+
92
+ def get_extra_fields (self , group_by_chain = True ):
93
+ # Generate dummy fields since NestedSampler does not produce these
94
+ n_chains = self .num_chains
95
+ n_samples = self .num_samples // self .num_chains
96
+
97
+ # Create dummy values for extra fields
98
+ extra_fields = {
99
+ "accept_prob" : np .full ((n_chains , n_samples ), 1.0 ), # Assume all proposals are accepted
100
+ "step_size" : np .full ((n_chains , n_samples ), 0.1 ), # Dummy step size
101
+ "num_steps" : np .full ((n_chains , n_samples ), 10 ), # Dummy number of steps
102
+ }
103
+
104
+ if not group_by_chain :
105
+ # Flatten the chains into a single dimension
106
+ extra_fields = {k : v .reshape (- 1 , * v .shape [2 :]) for k , v in extra_fields .items ()}
107
+
108
+ return extra_fields
109
+
110
+
16
111
class NumPyroConverter :
17
112
"""Encapsulate NumPyro specific logic."""
18
113
@@ -37,6 +132,10 @@ def __init__(
37
132
dims = None ,
38
133
pred_dims = None ,
39
134
num_chains = 1 ,
135
+ rng_key = None ,
136
+ num_samples = 1000 ,
137
+ data = None ,
138
+ labels = None ,
40
139
):
41
140
"""Convert NumPyro data into an InferenceData object.
42
141
@@ -68,6 +167,15 @@ def __init__(
68
167
import numpyro
69
168
70
169
self .posterior = posterior
170
+ self .rng_key = rng_key
171
+ self .num_samples = num_samples
172
+
173
+ if isinstance (posterior , numpyro .contrib .nested_sampling .NestedSampler ):
174
+ posterior = NestedToMCMCAdapter (
175
+ posterior , rng_key , num_samples , num_chains = num_chains , data = data , labels = labels
176
+ )
177
+ self .posterior = posterior
178
+
71
179
self .prior = jax .device_get (prior )
72
180
self .posterior_predictive = jax .device_get (posterior_predictive )
73
181
self .predictions = predictions
@@ -340,6 +448,10 @@ def from_numpyro(
340
448
dims = None ,
341
449
pred_dims = None ,
342
450
num_chains = 1 ,
451
+ rng_key = None ,
452
+ num_samples = 1000 ,
453
+ data = None ,
454
+ labels = None ,
343
455
):
344
456
"""Convert NumPyro data into an InferenceData object.
345
457
@@ -383,4 +495,8 @@ def from_numpyro(
383
495
dims = dims ,
384
496
pred_dims = pred_dims ,
385
497
num_chains = num_chains ,
498
+ rng_key = rng_key ,
499
+ num_samples = num_samples ,
500
+ data = data ,
501
+ labels = labels ,
386
502
).to_inference_data ()
0 commit comments