11import jax
22import jax .numpy as jnp
33from jaxtyping import Array , Float
4+ import pytest
45
56from flowMC .resource .nf_model .rqSpline import MaskedCouplingRQSpline
67from flowMC .resource .optimizer import Optimizer
@@ -78,22 +79,19 @@ def loss_fn(params: Float[Array, " n_dim"], data: dict = {}) -> Float:
7879
7980
8081class TestLocalStep :
81- def test_take_local_step (self ):
82+ @pytest .fixture (autouse = True )
83+ def setup (self ):
8284 n_chains = 5
8385 n_steps = 25
8486 n_dims = 2
8587 n_batch = 5
86-
8788 test_position = Buffer ("test_position" , (n_chains , n_steps , n_dims ), 1 )
8889 test_log_prob = Buffer ("test_log_prob" , (n_chains , n_steps ), 1 )
8990 test_acceptance = Buffer ("test_acceptance" , (n_chains , n_steps ), 1 )
90-
9191 mala_kernel = MALA (1.0 )
9292 grw_kernel = GaussianRandomWalk (1.0 )
9393 hmc_kernel = HMC (jnp .eye (n_dims ), 0.1 , 10 )
94-
9594 logpdf = LogPDF (log_posterior , n_dims = n_dims )
96-
9795 sampler_state = State (
9896 {
9997 "test_position" : "test_position" ,
@@ -102,8 +100,10 @@ def test_take_local_step(self):
102100 },
103101 name = "sampler_state" ,
104102 )
105-
106- resources = {
103+ self .n_batch = n_batch
104+ self .n_dims = n_dims
105+ self .test_position = test_position
106+ self .resources = {
107107 "test_position" : test_position ,
108108 "test_log_prob" : test_log_prob ,
109109 "test_acceptance" : test_acceptance ,
@@ -114,51 +114,74 @@ def test_take_local_step(self):
114114 "sampler_state" : sampler_state ,
115115 }
116116
117+ def test_take_local_step (self ):
117118 strategy = TakeSerialSteps (
118119 "logpdf" ,
119120 "MALA" ,
120121 "sampler_state" ,
121122 ["test_position" , "test_log_prob" , "test_acceptance" ],
122- n_batch ,
123+ self . n_batch ,
123124 )
124125 key = jax .random .PRNGKey (42 )
125- positions = test_position .data [:, 0 ]
126-
127- for i in range (n_batch ):
126+ positions = self .test_position .data [:, 0 ]
127+ for _ in range (self .n_batch ):
128128 key , subkey1 , subkey2 = jax .random .split (key , 3 )
129- _ , resources , positions = strategy (
129+ _ , self . resources , positions = strategy (
130130 rng_key = subkey1 ,
131- resources = resources ,
131+ resources = self . resources ,
132132 initial_position = positions ,
133- data = {"data" : jnp .arange (n_dims )},
133+ data = {"data" : jnp .arange (self . n_dims )},
134134 )
135-
136135 key , subkey1 , subkey2 = jax .random .split (key , 3 )
137136 strategy .set_current_position (0 )
138- _ , resources , positions = strategy (
137+ _ , self . resources , positions = strategy (
139138 rng_key = subkey1 ,
140- resources = resources ,
139+ resources = self . resources ,
141140 initial_position = positions ,
142- data = {"data" : jnp .arange (n_dims )},
141+ data = {"data" : jnp .arange (self . n_dims )},
143142 )
144-
145143 key , subkey1 , subkey2 = jax .random .split (key , 3 )
146144 strategy .kernel_name = "GRW"
147145 strategy .set_current_position (0 )
148- _ , resources , positions = strategy (
146+ _ , self . resources , positions = strategy (
149147 rng_key = subkey1 ,
150- resources = resources ,
148+ resources = self . resources ,
151149 initial_position = positions ,
152- data = {"data" : jnp .arange (n_dims )},
150+ data = {"data" : jnp .arange (self . n_dims )},
153151 )
154-
155152 strategy .kernel_name = "HMC"
156- _ , resources , positions = strategy (
153+ _ , self . resources , positions = strategy (
157154 rng_key = subkey1 ,
158- resources = resources ,
155+ resources = self . resources ,
159156 initial_position = positions ,
160- data = {"data" : jnp .arange (n_dims )},
157+ data = {"data" : jnp .arange (self .n_dims )},
158+ )
159+
160+ def test_take_local_step_chain_batch_size (self ):
161+ # Use a chain_batch_size smaller than the number of chains to trigger batching logic
162+ chain_batch_size = 2
163+ strategy = TakeSerialSteps (
164+ "logpdf" ,
165+ "MALA" ,
166+ "sampler_state" ,
167+ ["test_position" , "test_log_prob" , "test_acceptance" ],
168+ self .n_batch ,
169+ chain_batch_size = chain_batch_size ,
170+ )
171+ key = jax .random .PRNGKey (42 )
172+ positions = self .test_position .data [:, 0 ]
173+ # Run the strategy, which should use batching internally
174+ _ , _ , final_positions = strategy (
175+ rng_key = key ,
176+ resources = self .resources ,
177+ initial_position = positions ,
178+ data = {"data" : jnp .arange (self .n_dims )},
161179 )
180+ # Check that the output shape is correct
181+ assert final_positions .shape == (positions .shape [0 ], positions .shape [1 ])
182+ # Optionally, check that the buffer was updated for all chains
183+ assert isinstance (test_position := self .resources ["test_position" ], Buffer )
184+ assert test_position .data .shape [0 ] == positions .shape [0 ]
162185
163186
164187class TestNFStrategies :
0 commit comments