Commit de51252

sdaulton authored and facebook-github-bot committed

add contiguous call to fix failing tests (#2831)
Summary: Pull Request resolved: #2831

There is a sneaky upstream bug (pytorch/pytorch#151978) that caused some unit tests to fail and broke max posterior sampling. This contiguous call resolves it for now, until the upstream fix goes in.

Reviewed By: Balandat
Differential Revision: D73484645
fbshipit-source-id: 3daa91036a6db72c23046e9723a8a5726e71574a

1 parent 5180eed commit de51252

File tree

1 file changed: +4 additions, -2 deletions

botorch/generation/sampling.py

Lines changed: 4 additions & 2 deletions

@@ -145,8 +145,10 @@ def maximize_samples(self, X: Tensor, samples: Tensor, num_samples: int = 1):
         # to have shape batch_shape x num_samples
         if idcs.ndim > 1:
             idcs = idcs.permute(*range(1, idcs.ndim), 0)
-        # in order to use gather, we need to repeat the index tensor d times
-        idcs = idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1))
+        # in order to use gather, we need to repeat the index tensor d times.
+        # The contiguous call is needed due to a pytorch issue:
+        # https://github.com/pytorch/pytorch/issues/151978
+        idcs = idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1)).contiguous()
         # now if the model is batched batch_shape will not necessarily be the
         # batch_shape of X, so we expand X to the proper shape
         Xe = X.expand(*obj.shape[1:], X.size(-1))
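For context on what the patched lines do: `expand` creates a non-contiguous view of the index tensor (the new feature dimension is materialized with stride 0, not copied), and the upstream PyTorch bug made `gather` misbehave on such views, which is why the fix materializes the tensor with `.contiguous()` before gathering. Below is a minimal NumPy sketch of the same shape mechanics, with `np.broadcast_to` playing the role of `expand`, `np.ascontiguousarray` playing the role of `.contiguous()`, and `np.take_along_axis` playing the role of `torch.gather`; the shapes (`batch=2`, `n=5`, `d=3`, `num_samples=1`) and the hard-coded indices are illustrative assumptions, not values from the commit.

```python
import numpy as np

batch, n, d, num_samples = 2, 5, 3, 1

# X: batch of n candidate points with d features each
X = np.arange(batch * n * d, dtype=float).reshape(batch, n, d)

# idcs: index of the selected point per batch, shape (batch, num_samples).
# In the BoTorch code these come from an argmax over posterior samples;
# here they are hard-coded for illustration.
idcs = np.array([[4], [0]])

# Repeat the index across the feature dim, like
# idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1)).
# broadcast_to returns a non-contiguous, stride-0 view.
idcs_e = np.broadcast_to(idcs[..., None], (batch, num_samples, d))

# The analogue of the .contiguous() call added in this commit:
# materialize the view into a contiguous array before gathering.
idcs_c = np.ascontiguousarray(idcs_e)

# The analogue of torch.gather(X, -2, idcs): pick the indexed row
# from each batch, keeping all d features.
best = np.take_along_axis(X, idcs_c, axis=-2)  # shape (batch, num_samples, d)
```

NumPy's `take_along_axis` happens to handle non-contiguous index arrays fine; the sketch only shows why the expanded index view is non-contiguous and where the contiguous copy slots into the pipeline.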

0 commit comments
