Skip to content

Commit 0df5521

Browse files
saitcakmak authored and facebook-github-bot committed
Support mixing SAAS & SingleTaskGP models in ModelListGP (#2693)
Summary: Pull Request resolved: #2693 Adds support for broadcasting MVNs produced by the underlying models to enable mixing together SAAS & SingleTaskGP models within a ModelListGP. Reviewed By: sdaulton Differential Revision: D68503063 fbshipit-source-id: bffca2887fd5cf0f00c3503b3958acc367724941
1 parent 7b803bd commit 0df5521

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

botorch/models/gpytorch.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,7 @@ def posterior(
719719
interleaved=False,
720720
)
721721
else:
722+
mvns = self._broadcast_mvns(mvns=mvns)
722723
mvn = (
723724
mvns[0]
724725
if len(mvns) == 1
@@ -738,6 +739,38 @@ def posterior(
738739
def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
    """Condition the model on new observations (not implemented here).

    Args:
        X: Unused; accepted only to match the expected signature.
        Y: Unused; accepted only to match the expected signature.
        **kwargs: Unused.

    Raises:
        NotImplementedError: Always. The message names the concrete class so
            the failing model type is obvious from the traceback.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement `condition_on_observations`."
    )
740741

742+
def _broadcast_mvns(self, mvns: list[MultivariateNormal]) -> MultivariateNormal:
743+
"""Broadcasts the batch shapes of the given MultivariateNormals.
744+
745+
The MVNs will have a batch shape of `input_batch_shape x model_batch_shape`.
746+
If the model batch shapes are broadcastable, we will broadcast the mvns to
747+
a batch shape of `input_batch_shape x self.batch_shape`.
748+
749+
Args:
750+
mvns: A list of MultivariateNormals.
751+
752+
Returns:
753+
A list of MultivariateNormals with broadcasted batch shapes.
754+
"""
755+
mvn_batch_shapes = {mvn.batch_shape for mvn in mvns}
756+
if len(mvn_batch_shapes) == 1:
757+
# All MVNs have the same batch shape. We can return as is.
758+
return mvns
759+
# This call will error out if they're not broadcastable.
760+
# If they're broadcastable, it'll log a warning.
761+
target_model_shape = self.batch_shape
762+
max_batch = max(mvn_batch_shapes, key=len)
763+
max_len = len(max_batch)
764+
input_batch_len = max_len - len(target_model_shape)
765+
for i in range(len(mvns)): # Loop over index since we modify contents.
766+
while len(mvns[i].batch_shape) < max_len:
767+
# MVN is missing batch dimensions. Unsqueeze as needed.
768+
mvns[i] = mvns[i].unsqueeze(input_batch_len)
769+
if mvns[i].batch_shape != max_batch:
770+
# Expand to match the batch shapes.
771+
mvns[i] = mvns[i].expand(max_batch)
772+
return mvns
773+
741774

742775
class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
743776
r"""Abstract base class for multi-task models based on GPyTorch models.

test/models/test_model_list_gp_regression.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from botorch.acquisition.objective import ScalarizedPosteriorTransform
1313
from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
1414
from botorch.exceptions.warnings import OptimizationWarning
15-
from botorch.fit import fit_gpytorch_mll
15+
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
16+
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
1617
from botorch.models.gp_regression import SingleTaskGP
1718
from botorch.models.model_list_gp_regression import ModelListGP
1819
from botorch.models.multitask import MultiTaskGP
@@ -733,3 +734,29 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None:
733734
self.assertTrue(
734735
torch.equal(fm_i.train_inputs[0][0][-1], X[1 - i])
735736
)
737+
738+
def test_with_different_batch_shapes(self) -> None:
    """Check that a SAAS model and a SingleTaskGP can share a ModelListGP.

    The fully Bayesian model is fit with 8 samples, so the combined model
    list is expected to report a batch shape of `8`, and posteriors are
    expected to have batch shape `input_batch_shape x 8`.
    """
    tensor_opts = {"device": self.device, "dtype": torch.double}
    warning_regex = "Component models of"

    # Fully Bayesian (SAAS) component.
    saas_X = torch.rand(10, 2, **tensor_opts)
    saas_Y = torch.rand(10, 1, **tensor_opts)
    saas_model = SaasFullyBayesianSingleTaskGP(train_X=saas_X, train_Y=saas_Y)
    fit_fully_bayesian_model_nuts(
        saas_model, warmup_steps=0, num_samples=8, thinning=1
    )

    # Standard single-task GP component.
    gp_X = torch.rand(10, 2, **tensor_opts)
    gp_Y = torch.rand(10, 1, **tensor_opts)
    gp_model = SingleTaskGP(train_X=gp_X, train_Y=gp_Y)

    model_list = ModelListGP(saas_model, gp_model)
    with self.assertWarnsRegex(UserWarning, warning_regex):
        self.assertEqual(model_list.batch_shape, torch.Size([8]))

    # First pass is the non-batched evaluation, second pass the batched one.
    for input_batch in ([], [5]):
        test_X = torch.rand(*input_batch, 1, 2, **tensor_opts)
        with self.assertWarnsRegex(UserWarning, warning_regex):
            posterior = model_list.posterior(test_X)
        self.assertEqual(posterior.batch_shape, torch.Size([*input_batch, 8]))
        samples = posterior.rsample(torch.Size([2]))
        self.assertEqual(samples.shape, torch.Size([2, *input_batch, 8, 1, 2]))

0 commit comments

Comments
 (0)