|
11 | 11 | from copy import deepcopy
|
12 | 12 |
|
13 | 13 | from functools import partial
|
14 |
| -from itertools import count |
| 14 | +from itertools import count, product |
15 | 15 | from typing import Any
|
16 | 16 | from unittest.mock import patch
|
17 | 17 |
|
@@ -179,6 +179,29 @@ def _estimator(samples, bounds):
|
179 | 179 |
|
180 | 180 | self.assertAllClose(est, prob, rtol=0, atol=atol)
|
181 | 181 |
|
| 182 | + def test_solve_batch(self): |
| 183 | + ndim = 3 |
| 184 | + batch_shape = (3, 4) |
| 185 | + with torch.random.fork_rng(): |
| 186 | + torch.random.manual_seed(next(self.seed_generator)) |
| 187 | + bounds = self.gen_bounds(ndim, batch_shape, bound_range=(-5.0, +5.0)) |
| 188 | + sqrt_cov = self.gen_covariances(ndim, batch_shape, as_sqrt=True) |
| 189 | + |
| 190 | + cov = sqrt_cov @ sqrt_cov.mT |
| 191 | + |
| 192 | + batched_solver = MVNXPB(cov, bounds) |
| 193 | + batched_solver.solve() |
| 194 | + |
| 195 | + # solution for each individual batch element is the same as |
| 196 | + # that of the entire batch |
| 197 | + for idx in product(*map(range, batch_shape)): |
| 198 | + solver = MVNXPB(cov[tuple(idx)], bounds[tuple(idx)]) |
| 199 | + solver.solve() |
| 200 | + self.assertAlmostEqual( |
| 201 | + batched_solver.log_prob[tuple(idx)].item(), |
| 202 | + solver.log_prob.item(), |
| 203 | + ) |
| 204 | + |
182 | 205 | def test_augment(self):
|
183 | 206 | r"""Test `augment`."""
|
184 | 207 | with torch.random.fork_rng():
|
|
0 commit comments