Skip to content

Commit cbada4c

Browse files
ltiaofacebook-github-bot
authored andcommitted
Fix batch computation in Pivoted Cholesky (#2823)
Summary: Pull Request resolved: #2823 ## Context Resolves issue #2819 where `PivotedCholesky.update_` break when there is more than a single batch dimension. ## Changes Updates a line to extend boolean indexing logic to cases where `len(batch_shape) > 1` Reviewed By: saitcakmak Differential Revision: D72906531 fbshipit-source-id: 6d8088a96aaa3e9e1d6a799f1c6c9db617c1b9eb
1 parent 77a2867 commit cbada4c

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

botorch/utils/probability/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def update_(self, eps: float = 1e-10) -> None:
125125
rank1 = L[..., i + 1 :, i : i + 1].clone()
126126
rank1 = (rank1 * rank1.transpose(-1, -2)).tril()
127127
L[..., i + 1 :, i + 1 :] = L[..., i + 1 :, i + 1 :].clone() - rank1
128-
L[Lii <= i * eps, i:, i] = 0 # numerical stability clause
128+
L[..., i:, i][Lii <= i * eps] = 0 # numerical stability clause
129129
self.step += 1
130130

131131
def pivot_(self, pivot: LongTensor) -> None:

test/utils/probability/test_mvnxpb.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from copy import deepcopy
1212

1313
from functools import partial
14-
from itertools import count
14+
from itertools import count, product
1515
from typing import Any
1616
from unittest.mock import patch
1717

@@ -179,6 +179,29 @@ def _estimator(samples, bounds):
179179

180180
self.assertAllClose(est, prob, rtol=0, atol=atol)
181181

182+
def test_solve_batch(self):
183+
ndim = 3
184+
batch_shape = (3, 4)
185+
with torch.random.fork_rng():
186+
torch.random.manual_seed(next(self.seed_generator))
187+
bounds = self.gen_bounds(ndim, batch_shape, bound_range=(-5.0, +5.0))
188+
sqrt_cov = self.gen_covariances(ndim, batch_shape, as_sqrt=True)
189+
190+
cov = sqrt_cov @ sqrt_cov.mT
191+
192+
batched_solver = MVNXPB(cov, bounds)
193+
batched_solver.solve()
194+
195+
# solution for each individual batch element is the same as
196+
# that of the entire batch
197+
for idx in product(*map(range, batch_shape)):
198+
solver = MVNXPB(cov[tuple(idx)], bounds[tuple(idx)])
199+
solver.solve()
200+
self.assertAlmostEqual(
201+
batched_solver.log_prob[tuple(idx)].item(),
202+
solver.log_prob.item(),
203+
)
204+
182205
def test_augment(self):
183206
r"""Test `augment`."""
184207
with torch.random.fork_rng():

0 commit comments

Comments
 (0)