Commit c502bf9

Merge pull request #232 from kazewong/batch_take_step
Batch take step
2 parents c326bb7 + 9cd7427 commit c502bf9

6 files changed, +1173 −1104 lines changed


src/flowMC/resource/nf_model/NF_proposal.py
Lines changed: 9 additions & 10 deletions

@@ -91,22 +91,21 @@ def scan_sample(
         def body(carry, data):
             (
                 rng_key,
-                position_current,
-                log_prob_current,
-                log_prob_nf_current,
+                position_initial,
+                log_prob_initial,
+                log_prob_nf_initial,
             ) = carry
-            (position_proposed, log_prob_proposal, log_prob_nf_proposal) = data
-
+            (position_proposal, log_prob_proposal, log_prob_nf_proposal) = data
             rng_key, subkey = random.split(rng_key)
-            ratio = (log_prob_proposal - log_prob_current) - (
-                log_prob_nf_proposal - log_prob_nf_current
+            ratio = (log_prob_proposal - log_prob_initial) - (
+                log_prob_nf_proposal - log_prob_nf_initial
             )
             uniform_random = jnp.log(jax.random.uniform(subkey))
             do_accept = uniform_random < ratio
-            position_current = jnp.where(do_accept, position_proposed, position_current)
-            log_prob_current = jnp.where(do_accept, log_prob_proposal, log_prob_current)
+            position_current = jnp.where(do_accept, position_proposal, position_initial)
+            log_prob_current = jnp.where(do_accept, log_prob_proposal, log_prob_initial)
             log_prob_nf_current = jnp.where(
-                do_accept, log_prob_nf_proposal, log_prob_nf_current
+                do_accept, log_prob_nf_proposal, log_prob_nf_initial
            )

             return (rng_key, position_current, log_prob_current, log_prob_nf_current), (
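
The rename above does not change behavior: within one scan iteration, the chain's state at the start of the step (the *_initial quantities) is compared against a sample drawn from the normalizing flow (the *_proposal quantities), and the log acceptance ratio is log p(x') - log p(x) - (log q(x') - log q(x)), i.e. an independence Metropolis-Hastings correction for the flow proposal q. Below is a minimal, self-contained sketch of that single accept/reject step, mirroring the renamed variables; it is an illustration only, not flowMC's actual body function, and it assumes the log-densities have already been evaluated.

# Sketch of one independence-MH accept/reject step (illustration, not flowMC's code).
# All log-density inputs are assumed to be precomputed scalars; position arrays broadcast through jnp.where.
import jax
import jax.numpy as jnp

def mh_step(rng_key, position_initial, log_prob_initial, log_prob_nf_initial,
            position_proposal, log_prob_proposal, log_prob_nf_proposal):
    rng_key, subkey = jax.random.split(rng_key)
    # log [ p(x') q(x) / ( p(x) q(x') ) ]
    ratio = (log_prob_proposal - log_prob_initial) - (
        log_prob_nf_proposal - log_prob_nf_initial
    )
    do_accept = jnp.log(jax.random.uniform(subkey)) < ratio
    position_current = jnp.where(do_accept, position_proposal, position_initial)
    log_prob_current = jnp.where(do_accept, log_prob_proposal, log_prob_initial)
    log_prob_nf_current = jnp.where(
        do_accept, log_prob_nf_proposal, log_prob_nf_initial
    )
    return rng_key, position_current, log_prob_current, log_prob_nf_current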

src/flowMC/resource_strategy_bundle/RQSpline_MALA.py
Lines changed: 3 additions & 0 deletions

@@ -42,6 +42,7 @@ def __init__(
         n_production_loops: int,
         n_epochs: int,
         mala_step_size: float = 1e-1,
+        chain_batch_size: int = 0,
         rq_spline_hidden_units: list[int] = [32, 32],
         rq_spline_n_bins: int = 8,
         rq_spline_n_layers: int = 4,
@@ -135,6 +136,7 @@ def __init__(
            ["target_positions", "target_log_prob", "target_local_accs"],
            n_local_steps,
            thinning=local_thinning,
+           chain_batch_size=chain_batch_size,
            verbose=verbose,
        )

@@ -145,6 +147,7 @@ def __init__(
            ["target_positions", "target_log_prob", "target_global_accs"],
            n_global_steps,
            thinning=global_thinning,
+           chain_batch_size=chain_batch_size,
            verbose=verbose,
        )

src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py
Lines changed: 3 additions & 0 deletions

@@ -48,6 +48,7 @@ def __init__(
         n_production_loops: int,
         n_epochs: int,
         mala_step_size: float = 1e-1,
+        chain_batch_size: int = 0,
         rq_spline_hidden_units: list[int] = [32, 32],
         rq_spline_n_bins: int = 8,
         rq_spline_n_layers: int = 4,
@@ -165,6 +166,7 @@ def __init__(
            ["target_positions", "target_log_prob", "target_local_accs"],
            n_local_steps,
            thinning=local_thinning,
+           chain_batch_size=chain_batch_size,
            verbose=verbose,
        )

@@ -175,6 +177,7 @@ def __init__(
            ["target_positions", "target_log_prob", "target_global_accs"],
            n_global_steps,
            thinning=global_thinning,
+           chain_batch_size=chain_batch_size,
            verbose=verbose,
        )

src/flowMC/strategy/take_steps.py
Lines changed: 32 additions & 5 deletions

@@ -6,6 +6,7 @@
 from flowMC.strategy.base import Strategy
 from jaxtyping import Array, Float, PRNGKeyArray
 import jax
+import jax.numpy as jnp
 import equinox as eqx
 from abc import abstractmethod

@@ -18,6 +19,7 @@ class TakeSteps(Strategy):
     n_steps: int
     current_position: int
     thinning: int
+    chain_batch_size: int  # If vmap over a large number of chains is memory bounded, this splits the computation
     verbose: bool

     def __init__(
@@ -28,6 +30,7 @@ def __init__(
         buffer_names: list[str],
         n_steps: int,
         thinning: int = 1,
+        chain_batch_size: int = 0,
         verbose: bool = False,
     ):
         self.logpdf_name = logpdf_name
@@ -37,6 +40,7 @@ def __init__(
         self.n_steps = n_steps
         self.current_position = 0
         self.thinning = thinning
+        self.chain_batch_size = chain_batch_size
         self.verbose = verbose

     @abstractmethod
@@ -98,11 +102,34 @@ def __call__(

         # Filter jit will bypass the compilation of
         # the function if not clearing the cache
-        positions, log_probs, do_accepts = eqx.filter_jit(
-            eqx.filter_vmap(
-                jax.tree_util.Partial(self.sample, kernel), in_axes=(0, 0, None, None)
-            )
-        )(subkey, initial_position, logpdf, data)
+        n_chains = initial_position.shape[0]
+        if self.chain_batch_size > 1 and n_chains > self.chain_batch_size:
+            positions_list = []
+            log_probs_list = []
+            do_accepts_list = []
+            for i in range(0, n_chains, self.chain_batch_size):
+                batch_slice = slice(i, min(i + self.chain_batch_size, n_chains))
+                subkey_batch = subkey[batch_slice]
+                initial_position_batch = initial_position[batch_slice]
+                positions_batch, log_probs_batch, do_accepts_batch = eqx.filter_jit(
+                    eqx.filter_vmap(
+                        jax.tree_util.Partial(self.sample, kernel),
+                        in_axes=(0, 0, None, None),
+                    )
+                )(subkey_batch, initial_position_batch, logpdf, data)
+                positions_list.append(positions_batch)
+                log_probs_list.append(log_probs_batch)
+                do_accepts_list.append(do_accepts_batch)
+            positions = jnp.concatenate(positions_list, axis=0)
+            log_probs = jnp.concatenate(log_probs_list, axis=0)
+            do_accepts = jnp.concatenate(do_accepts_list, axis=0)
+        else:
+            positions, log_probs, do_accepts = eqx.filter_jit(
+                eqx.filter_vmap(
+                    jax.tree_util.Partial(self.sample, kernel),
+                    in_axes=(0, 0, None, None),
+                )
+            )(subkey, initial_position, logpdf, data)

         positions = positions[:, :: self.thinning]
         log_probs = log_probs[:, :: self.thinning]
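
This is the core of the change: instead of one large vmap over every chain, the chains are stepped in slices of chain_batch_size and the per-batch results are concatenated along the chain axis, which caps peak device memory at the cost of extra dispatches. The batching path is taken only when chain_batch_size is greater than 1 and there are more chains than that; with the default chain_batch_size = 0, the original single-vmap path is used. The following is a self-contained JAX sketch of the same pattern, independent of flowMC; heavy_step, batched_vmap, and the shapes are made up for illustration.

# Illustration of batching a vmapped computation to bound peak memory.
# heavy_step and the array shapes are hypothetical stand-ins.
import jax
import jax.numpy as jnp

def heavy_step(x):
    # Stand-in for an expensive per-chain computation with a large output.
    return jnp.sum(x**2) * jnp.ones(1000)

def batched_vmap(fn, inputs, batch_size):
    n = inputs.shape[0]
    if batch_size <= 1 or n <= batch_size:
        # No batching: vmap over all rows at once.
        return jax.vmap(fn)(inputs)
    outputs = []
    for i in range(0, n, batch_size):
        # vmap over one slice of rows at a time, then stitch the results back together.
        outputs.append(jax.vmap(fn)(inputs[i : i + batch_size]))
    return jnp.concatenate(outputs, axis=0)

chains = jnp.ones((128, 3))
out = batched_vmap(heavy_step, chains, batch_size=32)
print(out.shape)  # (128, 1000)

One practical note, which is a general JAX property rather than anything specific to this commit: if the number of chains is not a multiple of the batch size, the final, smaller batch has a different shape and triggers an additional JIT compilation, so a batch size that divides the number of chains avoids that extra compile.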

test/unit/test_strategies.py
Lines changed: 49 additions & 26 deletions

@@ -1,6 +1,7 @@
 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, Float
+import pytest

 from flowMC.resource.nf_model.rqSpline import MaskedCouplingRQSpline
 from flowMC.resource.optimizer import Optimizer
@@ -78,22 +79,19 @@ def loss_fn(params: Float[Array, " n_dim"], data: dict = {}) -> Float:


 class TestLocalStep:
-    def test_take_local_step(self):
+    @pytest.fixture(autouse=True)
+    def setup(self):
         n_chains = 5
         n_steps = 25
         n_dims = 2
         n_batch = 5
-
         test_position = Buffer("test_position", (n_chains, n_steps, n_dims), 1)
         test_log_prob = Buffer("test_log_prob", (n_chains, n_steps), 1)
         test_acceptance = Buffer("test_acceptance", (n_chains, n_steps), 1)
-
         mala_kernel = MALA(1.0)
         grw_kernel = GaussianRandomWalk(1.0)
         hmc_kernel = HMC(jnp.eye(n_dims), 0.1, 10)
-
         logpdf = LogPDF(log_posterior, n_dims=n_dims)
-
         sampler_state = State(
             {
                 "test_position": "test_position",
@@ -102,8 +100,10 @@ def test_take_local_step(self):
             },
             name="sampler_state",
         )
-
-        resources = {
+        self.n_batch = n_batch
+        self.n_dims = n_dims
+        self.test_position = test_position
+        self.resources = {
             "test_position": test_position,
             "test_log_prob": test_log_prob,
             "test_acceptance": test_acceptance,
@@ -114,51 +114,74 @@
             "sampler_state": sampler_state,
         }

+    def test_take_local_step(self):
         strategy = TakeSerialSteps(
             "logpdf",
             "MALA",
             "sampler_state",
             ["test_position", "test_log_prob", "test_acceptance"],
-            n_batch,
+            self.n_batch,
         )
         key = jax.random.PRNGKey(42)
-        positions = test_position.data[:, 0]
-
-        for i in range(n_batch):
+        positions = self.test_position.data[:, 0]
+        for _ in range(self.n_batch):
             key, subkey1, subkey2 = jax.random.split(key, 3)
-            _, resources, positions = strategy(
+            _, self.resources, positions = strategy(
                 rng_key=subkey1,
-                resources=resources,
+                resources=self.resources,
                 initial_position=positions,
-                data={"data": jnp.arange(n_dims)},
+                data={"data": jnp.arange(self.n_dims)},
             )
-
         key, subkey1, subkey2 = jax.random.split(key, 3)
         strategy.set_current_position(0)
-        _, resources, positions = strategy(
+        _, self.resources, positions = strategy(
             rng_key=subkey1,
-            resources=resources,
+            resources=self.resources,
             initial_position=positions,
-            data={"data": jnp.arange(n_dims)},
+            data={"data": jnp.arange(self.n_dims)},
         )
-
         key, subkey1, subkey2 = jax.random.split(key, 3)
         strategy.kernel_name = "GRW"
         strategy.set_current_position(0)
-        _, resources, positions = strategy(
+        _, self.resources, positions = strategy(
             rng_key=subkey1,
-            resources=resources,
+            resources=self.resources,
             initial_position=positions,
-            data={"data": jnp.arange(n_dims)},
+            data={"data": jnp.arange(self.n_dims)},
         )
-
         strategy.kernel_name = "HMC"
-        _, resources, positions = strategy(
+        _, self.resources, positions = strategy(
             rng_key=subkey1,
-            resources=resources,
+            resources=self.resources,
             initial_position=positions,
-            data={"data": jnp.arange(n_dims)},
+            data={"data": jnp.arange(self.n_dims)},
+        )
+
+    def test_take_local_step_chain_batch_size(self):
+        # Use a chain_batch_size smaller than the number of chains to trigger batching logic
+        chain_batch_size = 2
+        strategy = TakeSerialSteps(
+            "logpdf",
+            "MALA",
+            "sampler_state",
+            ["test_position", "test_log_prob", "test_acceptance"],
+            self.n_batch,
+            chain_batch_size=chain_batch_size,
+        )
+        key = jax.random.PRNGKey(42)
+        positions = self.test_position.data[:, 0]
+        # Run the strategy, which should use batching internally
+        _, _, final_positions = strategy(
+            rng_key=key,
+            resources=self.resources,
+            initial_position=positions,
+            data={"data": jnp.arange(self.n_dims)},
         )
+        # Check that the output shape is correct
+        assert final_positions.shape == (positions.shape[0], positions.shape[1])
+        # Optionally, check that the buffer was updated for all chains
+        assert isinstance(test_position := self.resources["test_position"], Buffer)
+        assert test_position.data.shape[0] == positions.shape[0]


 class TestNFStrategies:
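
The new test drives the batching path with chain_batch_size=2 against 5 chains and checks that the concatenated output keeps its (n_chains, n_dims) shape and that the position buffer still covers every chain. Assuming a standard pytest setup for the repository, it can be run on its own with pytest's node-id syntax:

pytest test/unit/test_strategies.py::TestLocalStep::test_take_local_step_chain_batch_size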
